Conversation

@asfiyab-nvidia (Owner):

  1. Scale reinitialization is fixed by wrapping module init in fp8_autocast and providing a recipe (see the sketch after this list).
  2. Update the scale_factors passed to the Linear, LayernormLinear, and LayernormMLP layers by providing the num_gemms param and two scale values per GEMM.
  3. Add configurable scale initialization for MHA and Transformer Layer.
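
A minimal sketch of item 1, assuming the standard transformer_engine.pytorch API (fp8_autocast with a DelayedScaling recipe); the PR's actual export path may differ:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    # Reinitialize FP8 scaling state by constructing and running the module
    # under fp8_autocast with an explicit recipe.
    fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        model = te.Linear(128, 128)                       # module init under the autocast
        out = model(torch.randn(16, 128, device="cuda"))  # forward pass populates fp8_meta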

asfiyab-nvidia and others added 12 commits January 4, 2023 21:32
* Add TorchScript Operators
* Add symbolic methods to ONNX exporter
* Add tests for the ONNX export

Signed-off-by: Asfiya Baig <[email protected]>
* Increase layernorm FP16 threshold
* Normalize onnx file names: _ separates configs; - separates words in a single config
* Add get_attn_mask_str and fix mask string
* Add missing ONNX files
* Moved generated ONNX files to tests/gen_onnx_models/

Signed-off-by: Asfiya Baig <[email protected]>
1. Remove the List import to fix a pylint failure
2. Address comments: remove state tensors from the GPU
3. Address comments: update the reverse_map_dtype function and add it to the namespace

Signed-off-by: Asfiya Baig <[email protected]>
Reviewer:

I don't think we need this assert because it checks a TE predicate that should not affect the export process.

Reviewer:

You can remove num_gemms because we don't need/use it.
Please change the type hint for scales to List.

Suggested change:
- def set_layer_scale(module: torch.nn.Module, scales: float, num_gemms: int=1):
+ def set_layer_scale(module: torch.nn.Module, scales: List[float]):

@asfiyab-nvidia (Owner, Author):

num_gemms is required for the fp8_init() call. For LayernormMLP specifically, it is set to 2 (see the line below).
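
For context, a brief sketch of that call, with assumptions flagged: fp8_init(num_gemms=...) is the method referenced above, and LayernormMLP runs two GEMMs (fc1 and fc2); the exact signature is not verified here.

    import transformer_engine.pytorch as te

    model = te.LayerNormMLP(hidden_size=128, ffn_hidden_size=512)
    model.fp8_init(num_gemms=2)  # assumed: allocates FP8 scaling state for both GEMMs (fc1, fc2)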

Reviewer:

My bad - I missed that - ignore my comment

Reviewer:

Suggested change:
- scale_factor: list,
+ scale_factor: List[float],

Reviewer:

Suggested change:
- @pytest.mark.parametrize("scale_factor", [[448, 448]])
+ @pytest.mark.parametrize("scale_factors", [[448, 448]])

Reviewer:

Suggested change:
- @pytest.mark.parametrize("scale_factor", [[448, 448]])
+ @pytest.mark.parametrize("scale_factors", [[448, 448]])

Made it explicitly plural

Reviewer:

Suggested change:
- set_layer_scale(model, scale_factor, num_gemms=2)
+ set_layer_scale(model, scale_factors)

Reviewer (on lines 781 to 784):

Suggested change:
- scale_factor_qkv: list=[448, 448],
- scale_factor_query: list=[112, 112],
- scale_factor_kv: list=[224, 224],
- scale_factor_proj: list=[448, 448]
+ scale_factor_qkv: List[float]=[448, 448],
+ scale_factor_query: List[float]=[112, 112],
+ scale_factor_kv: List[float]=[224, 224],
+ scale_factor_proj: List[float]=[448, 448]

Reviewer (on lines 845 to 848):

Suggested change:
- scale_factor_qkv: list,
- scale_factor_query: list,
- scale_factor_kv: list,
- scale_factor_proj: list,
+ scale_factors_qkv: List[float],
+ scale_factors_query: List[float],
+ scale_factors_kv: List[float],
+ scale_factors_proj: List[float],

Reviewer:

Suggested change:
- scales_layernorm_mlp: list=[224, 224, 448, 448]):
+ scales_layernorm_mlp: List[float]=[224, 224, 448, 448]):

Reviewer (on lines 945 to 949):

Suggested change:
- scale_factor_qkv: list,
- scale_factor_query: list,
- scale_factor_kv: list,
- scale_factor_proj: list,
- scale_factor_layernorm_mlp: list,
+ scale_factor_qkv: List[float],
+ scale_factor_query: List[float],
+ scale_factor_kv: List[float],
+ scale_factor_proj: List[float],
+ scale_factor_layernorm_mlp: List[float],

Signed-off-by: Asfiya Baig <[email protected]>
1. Replace variable scale_factor with scale_factors
2. Update type hints for scale_factors to be List[float]
3. Remove use of the num_gemms param and add the amax_history assignment (see the sketch below)

Signed-off-by: Asfiya Baig <[email protected]>
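
For readers following the thread, a minimal sketch of what set_layer_scale might look like after these changes; the fp8_meta["scaling_fwd"] key and tensor names are assumptions based on this discussion, not the PR's exact code:

    from typing import List
    import torch

    def set_layer_scale(module: torch.nn.Module, scales: List[float]):
        # Assumed TransformerEngine metadata layout: one scale value per quantized tensor.
        meta = module.fp8_meta["scaling_fwd"]
        meta.scale = torch.tensor(scales, device="cuda")
        meta.scale_inv = 1.0 / meta.scale
        # amax_history assignment mentioned in item 3; assumed shape (history_len, num_tensors).
        meta.amax_history = torch.zeros(1, len(scales), device="cuda")

With the LayernormMLP defaults above, for example, set_layer_scale(model, [224, 224, 448, 448]) would supply two scale values for each of its two GEMMs.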