The current JAX tree-building implementation constructs the tree from the events in the tracefile. The node categories currently include 'Unknown', 'kernel', 'cpu_op', 'memcpy', and 'python function'. I believe we need to add more nodes/events to the current JAX trace2tree to make it a true tree and enable richer analysis.
1. Proposal: add a new event category “hlo_op” and corresponding nodes to the tree:
hlo_op is a key representation level in JAX/XLA. It can be used to obtain information such as input size, output size, data type, operation type, data dependencies, and so on. In several cases, a single hlo_op has multiple instances: for example, the attention operation is lowered to custom-call.x.0, which is called multiple times because the model executes the same layer repeatedly. Furthermore, a single instance of this hlo_op can produce multiple kernels.
By introducing new nodes that represent the hlo_op level in the tree, we can perform more accurate performance analysis, categorize GPU ops, capture dependencies between GPU kernels, and so on. A sketch of such an insertion is shown below.
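To make this concrete, a minimal sketch of the grouping step follows. This is not the actual trace2tree API: the `Node` class, its fields, and the assumption that kernels carry `hlo_op`/`hlo_module` metadata in an `args` dict are illustrative assumptions.

```python
from collections import defaultdict
from dataclasses import dataclass, field

@dataclass
class Node:
    name: str
    cat: str                       # 'kernel', 'cpu_op', 'hlo_op', ...
    ts: float = 0.0                # start timestamp (us)
    dur: float = 0.0               # duration (us)
    args: dict = field(default_factory=dict)
    children: list = field(default_factory=list)

def insert_hlo_op_nodes(parent: Node) -> None:
    """Group the kernel children of `parent` under new hlo_op nodes."""
    groups, others = defaultdict(list), []
    for child in parent.children:
        hlo_op = child.args.get("hlo_op")
        if child.cat == "kernel" and hlo_op is not None:
            groups[(child.args.get("hlo_module"), hlo_op)].append(child)
        else:
            others.append(child)

    new_children = list(others)
    for (hlo_module, hlo_op), kernels in groups.items():
        start = min(k.ts for k in kernels)
        end = max(k.ts + k.dur for k in kernels)
        # Wall-clock span of the grouped kernels approximates the hlo_op duration.
        hlo_node = Node(name=hlo_op, cat="hlo_op", ts=start, dur=end - start,
                        args={"hlo_module": hlo_module}, children=kernels)
        new_children.append(hlo_node)
    parent.children = new_children

    for child in others:           # recurse into the remaining subtrees
        insert_hlo_op_nodes(child)
```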
Example for attention op:
Currently, only the key kernels are categorized as FA kernels, and those alone are used to derive performance (runtime, TFLOP/sec/device):
name:kernel_func, dur:16488.654, pid:8, cat:kernel, gpu_kernel_op_cat:FA V3, hlo_module:jit_train_step, hlo_op:custom-call.63.0
However, if we augment the tree with hlo_op nodes, traversing the subtree rooted at the hlo_op node gives us much richer information about the attention call:
custom-call.63.0 dur: 17176.08 cat:hlo_op hlo_op_event_type=te_fused_attn_backward_ffi
|_name:FillBuffer, dur:79.15, pid:8, cat:kernel, gpu_cat:Conv
|_name:FillBuffer, dur:193.28, pid:8, cat:kernel, gpu_cat:Conv
|_name:void ck_tile::kentry<64, 2, ck_tile::FmhaBwdOGradDotOKernel<, dur:414.98, pid:8, cat:kernel, gpu_cat:FA BWD
|_name:kernel_func, dur:16488.65, pid:8, cat:kernel, gpu_cat:FA V3
Traversing the hlo_op subtree above gives the true runtime (and idle time) of the attention op, and therefore its true performance. A sketch of such a traversal follows.
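A rough sketch of that traversal, reusing the hypothetical `Node` class from the snippet above (TFLOP/sec/device would then be the op's FLOP count divided by either the busy or the wall time, whichever definition is wanted):

```python
def hlo_op_perf(hlo_node: Node) -> dict:
    """Sum kernel time under an hlo_op node and estimate its idle time.
    Overlapping kernels are not merged here; a real implementation would
    union the kernel intervals before computing busy time."""
    busy, start, end = 0.0, float("inf"), float("-inf")
    stack = [hlo_node]
    while stack:
        node = stack.pop()
        if node.cat == "kernel":
            busy += node.dur
            start = min(start, node.ts)
            end = max(end, node.ts + node.dur)
        stack.extend(node.children)
    wall = max(end - start, 0.0)
    return {"busy_us": busy, "wall_us": wall, "idle_us": max(wall - busy, 0.0)}
```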
2. Add nodes for the different levels of the framework name-scope call stack.
Adding nodes for the different framework call-stack levels and linking them will enable framework-operation-level analysis (e.g., layer_0.mlp). Currently, the framework call-stack information sits in the metadata of gpu_kernel events. In some cases, meta events (tid >> 10000) and stackframe information can be used to improve this call-stack information. The task can therefore be divided into two parts:
(1) merge the meta events and stackframe information from the tracefile to improve the call-stack information (see the sketch below), and
(2) parse the call-stack information to augment the JAX tree by adding and linking new nodes (see the sketch after the example tree below).
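For part (1), a possible starting point is sketched below. It assumes the tracefile follows the Chrome trace-event JSON layout with a top-level `traceEvents` list and an optional `stackFrames` table; the tid threshold and the exact field names are assumptions about this particular trace, not something the format guarantees.

```python
import json

def load_trace(path: str, meta_tid_threshold: int = 10000):
    """Split trace events into regular vs. meta rows and return the
    stackFrames table (if present) for call-stack resolution."""
    with open(path) as f:
        trace = json.load(f)
    events = trace.get("traceEvents", [])
    stack_frames = trace.get("stackFrames", {})

    regular, meta = [], []
    for ev in events:
        (meta if ev.get("tid", 0) > meta_tid_threshold else regular).append(ev)
    return regular, meta, stack_frames

def resolve_stack(frame_id, stack_frames):
    """Walk parent links in the stackFrames table, returning names root -> leaf."""
    names = []
    while frame_id is not None:
        frame = stack_frames.get(str(frame_id), {})
        names.append(frame.get("name", "?"))
        frame_id = frame.get("parent")
    return list(reversed(names))
```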
Example:
We can add nodes for every level of the framework name-scope call stack shown below. This makes it possible to get a subtree for self_attention, which can then be used to collect all GPU events launched for self_attention.
└── jit(train_step)
└── jit(main)
└── transpose(jvp(Transformer))
└── decoder
└── while
└── body
└── checkpoint
└── layers
└── self_attention
└── attention_op
└── attention_op.apply_attention
└── attention_op.cudnn_flash_attention
└── DotProductAttention_0
└── _FusedDotProductAttention_0
└── custom_partitioning
hlo_module:jit_train_step
hlo_op:custom-call.63.0
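For part (2), a sketch of how the levels above could become linked nodes, again using the hypothetical `Node` class and assuming the name scope arrives as a '/'-separated string in the kernel metadata:

```python
def attach_by_name_scope(root: Node, scope_path: str, kernel_node: Node,
                         sep: str = "/") -> Node:
    """Create (or reuse) one node per name-scope level and hang the kernel
    under the deepest level, e.g.
    scope_path = "jit(train_step)/jit(main)/.../self_attention/..."
    """
    current = root
    for level in scope_path.split(sep):
        match = next((c for c in current.children
                      if c.cat == "name_scope" and c.name == level), None)
        if match is None:
            match = Node(name=level, cat="name_scope")
            current.children.append(match)
        current = match
    current.children.append(kernel_node)
    return current
```

Because existing scope nodes are reused across kernels, every GPU event launched under self_attention ends up in the same subtree, so collecting them becomes a plain subtree traversal.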