Eliminate scale re-initialization #3
base: dev-onnx-export-support
Conversation
* Add TorchScript operators
* Add symbolic methods to the ONNX exporter
* Add tests for the ONNX export

Signed-off-by: Asfiya Baig <[email protected]>
* Increase the layernorm FP16 threshold
* Normalize ONNX file names: `_` separates configs; `-` separates words within a single config
* Add get_attn_mask_str and fix the mask string
* Add missing ONNX files
* Move generated ONNX files to tests/gen_onnx_models/

Signed-off-by: Asfiya Baig <[email protected]>
1. Remove the List import to fix a pylint failure
2. Address comments: remove state tensors from the GPU
3. Address comments: update the reverse_map_dtype function and add it to the namespace

Signed-off-by: Asfiya Baig <[email protected]>
tests/test_onnx_export.py
Outdated
I don't think we need this assert because it checks a TE predicate that should not affect the export process.
tests/test_onnx_export.py
Outdated
You can remove num_gemms because we don't need/use it.
Please change the typing hint for scales to List.
Suggested change:
-def set_layer_scale(module: torch.nn.Module, scales: float, num_gemms: int=1):
+def set_layer_scale(module: torch.nn.Module, scales: List[float]):
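The suggested signature takes one scale per GEMM, so the list length itself implies the GEMM count. A minimal sketch of that idea, using stdlib stubs in place of torch and Transformer Engine (the `fp8_meta` attribute names here are assumptions, not TE's real layout):

```python
# Sketch of the suggested set_layer_scale: the scales list replaces the
# num_gemms count. FP8Meta and StubModule stand in for TE/torch objects.
from dataclasses import dataclass, field
from typing import List

@dataclass
class FP8Meta:                      # stand-in for TE's fp8 metadata (assumed names)
    scale: List[float] = field(default_factory=list)
    amax_history: List[float] = field(default_factory=list)

@dataclass
class StubModule:                   # stand-in for torch.nn.Module
    fp8_meta: FP8Meta = field(default_factory=FP8Meta)

def set_layer_scale(module: StubModule, scales: List[float]) -> None:
    """Assign one FP8 scale per GEMM; len(scales) implies the GEMM count."""
    module.fp8_meta.scale = list(scales)
    module.fp8_meta.amax_history = [0.0] * len(scales)

mlp = StubModule()
set_layer_scale(mlp, [224.0, 224.0])   # e.g. a module with two GEMMs
```

With this shape, a caller can no longer pass a scale count that disagrees with the number of scale values, which is the point of dropping `num_gemms`.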
num_gemms is required for the fp8_init() call. For LayernormMLP specifically, it is set to 2. (see line below)
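To make the reviewer's point concrete: FP8 initialization allocates one set of scaling state per GEMM, which is why LayernormMLP (two GEMMs) passes 2. A toy stub of that behavior, where `fp8_init` is a mock and not TE's real call signature:

```python
# Hedged stub of why num_gemms matters: fp8 initialization allocates one
# scale / amax slot per GEMM. The real TE API differs; this is illustrative.
def fp8_init(num_gemms: int = 1):
    return {
        "scale": [1.0] * num_gemms,        # one scale per GEMM
        "amax_history": [0.0] * num_gemms,  # one amax slot per GEMM
    }

meta = fp8_init(num_gemms=2)   # LayernormMLP: two GEMMs
```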
My bad, I missed that; please ignore my comment.
tests/test_onnx_export.py
Outdated
Suggested change:
-scale_factor: list,
+scale_factor: List[float],
tests/test_onnx_export.py
Outdated
Suggested change:
-@pytest.mark.parametrize("scale_factor", [[448, 448]])
+@pytest.mark.parametrize("scale_factors", [[448, 448]])
tests/test_onnx_export.py
Outdated
Suggested change:
-@pytest.mark.parametrize("scale_factor", [[448, 448]])
+@pytest.mark.parametrize("scale_factors", [[448, 448]])
Made it explicitly plural
tests/test_onnx_export.py
Outdated
Suggested change:
-set_layer_scale(model, scale_factor, num_gemms=2)
+set_layer_scale(model, scale_factors)
tests/test_onnx_export.py
Outdated
Suggested change:
-scale_factor_qkv: list=[448, 448],
-scale_factor_query: list=[112, 112],
-scale_factor_kv: list=[224, 224],
-scale_factor_proj: list=[448, 448]
+scale_factor_qkv: List[float]=[448, 448],
+scale_factor_query: List[float]=[112, 112],
+scale_factor_kv: List[float]=[224, 224],
+scale_factor_proj: List[float]=[448, 448]
tests/test_onnx_export.py
Outdated
Suggested change:
-scale_factor_qkv: list,
-scale_factor_query: list,
-scale_factor_kv: list,
-scale_factor_proj: list,
+scale_factors_qkv: List[float],
+scale_factors_query: List[float],
+scale_factors_kv: List[float],
+scale_factors_proj: List[float],
tests/test_onnx_export.py
Outdated
Suggested change:
-scales_layernorm_mlp: list=[224, 224, 448, 448]):
+scales_layernorm_mlp: List[float]=[224, 224, 448, 448]):
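The four values here follow the convention noted elsewhere in this thread: two scale values per GEMM, and LayernormMLP has two GEMMs. A small sketch of reading the flat list that way (the pairing layout is an assumption about how the test consumes it):

```python
# Hedged sketch: interpret a flat scales list as groups of two values per
# GEMM. [224, 224, 448, 448] -> one pair for each of LayernormMLP's 2 GEMMs.
from typing import List

def split_per_gemm(scales: List[float], vals_per_gemm: int = 2) -> List[List[float]]:
    """Chunk a flat scales list into one sublist per GEMM."""
    return [scales[i:i + vals_per_gemm]
            for i in range(0, len(scales), vals_per_gemm)]

per_gemm = split_per_gemm([224.0, 224.0, 448.0, 448.0])
# per_gemm -> [[224.0, 224.0], [448.0, 448.0]]
```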
tests/test_onnx_export.py
Outdated
Suggested change:
-scale_factor_qkv: list,
-scale_factor_query: list,
-scale_factor_kv: list,
-scale_factor_proj: list,
-scale_factor_layernorm_mlp: list,
+scale_factor_qkv: List[float],
+scale_factor_query: List[float],
+scale_factor_kv: List[float],
+scale_factor_proj: List[float],
+scale_factor_layernorm_mlp: List[float],
1. Replace the variable scale_factor with scale_factors
2. Update the type hints for scale_factors to List[float]
3. Remove use of the num_gemms param and add the amax_history assignment

Signed-off-by: Asfiya Baig <[email protected]>
Force-pushed from 325fe62 to e83246e.
Force-pushed from 119a0ec to ab4410f.
num_gemms param and providing 2 scale values per gemm.