Commit 6ab0afc

farhadrgh, dorotat-nv, cspades, trvachov, and nvdreidenbach authored
Update EVO2 tests according to Hyena arch changes (#798)
### Description

NVIDIA-NeMo/NeMo#12856 introduces code reduction and perf improvements, including standardizing input/output shapes for Hyena operators and consequently reducing rearrangement overhead. This PR updates the EVO2 tests to comply with those changes.

### Type of changes

<!-- Mark the relevant option with an [x] -->

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Refactor
- [ ] Documentation update
- [ ] Other (please describe):

### CI Pipeline Configuration

Configure CI behavior by applying the relevant labels:

- [SKIP_CI](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#skip_ci) - Skip all continuous integration tests
- [INCLUDE_NOTEBOOKS_TESTS](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#include_notebooks_tests) - Execute notebook validation tests in pytest
- [INCLUDE_SLOW_TESTS](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#include_slow_tests) - Execute tests labelled as slow in pytest for extensive testing

> [!NOTE]
> By default, the notebook validation tests are skipped unless explicitly enabled.

#### Authorizing CI Runs

We use [copy-pr-bot](https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/#automation) to manage authorization of CI runs on NVIDIA's compute resources.

* If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123).
* If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an `/ok to test` comment on the pull request to trigger CI. This will need to be done for each new commit.
### Usage

<!--- How does a user interact with the changed code -->

```python
TODO: Add code snippet
```

### Pre-submit Checklist

<!--- Ensure all items are completed before submitting -->

- [ ] I have tested these changes locally
- [ ] I have updated the documentation accordingly
- [ ] I have added/updated tests as needed
- [ ] All existing tests pass successfully

---------

Signed-off-by: Farhad Ramezanghorbani <farhadr@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: cspades <cory0ye@gmail.com>
Signed-off-by: Timur Rvachov <trvachov@nvidia.com>
Signed-off-by: Danny <dreidenbach@nvidia.com>
Signed-off-by: Cory Ye <44509866+cspades@users.noreply.github.com>
Signed-off-by: nvdreidenbach <97637601+nvdreidenbach@users.noreply.github.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
Signed-off-by: polinabinder1 <pbinder@nvidia.com>
Signed-off-by: dorotat <dorotat@nvidia.com>
Signed-off-by: Truong Nguyen <tgnguyen@nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
Signed-off-by: Timur Rvachov <120140748+trvachov@users.noreply.github.com>
Signed-off-by: Steven <skothenhill@nvidia.com>
Co-authored-by: Dorota Toczydlowska <115542912+dorotat-nv@users.noreply.github.com>
Co-authored-by: Cory Ye <44509866+cspades@users.noreply.github.com>
Co-authored-by: Timur Rvachov <120140748+trvachov@users.noreply.github.com>
Co-authored-by: nvdreidenbach <97637601+nvdreidenbach@users.noreply.github.com>
Co-authored-by: Steven Kothen-Hill <148821680+skothenhill-nv@users.noreply.github.com>
Co-authored-by: Peter St. John <pstjohn@nvidia.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: polinabinder1 <pbinder@nvidia.com>
Co-authored-by: Truong Nguyen <tgnguyen@nvidia.com>
Co-authored-by: jomitchellnv <148147880+jomitchellnv@users.noreply.github.com>
Co-authored-by: lvojtku <lvojtku@nvidia.com>
1 parent 3936231 commit 6ab0afc

File tree

11 files changed

+62
-61
lines changed


3rdparty/Megatron-LM

Submodule Megatron-LM updated 969 files

3rdparty/NeMo

Submodule NeMo updated from b685967 to 42d2b55

sub-packages/bionemo-amplify/tests/bionemo/amplify/test_hf_rotary.py

Lines changed: 15 additions & 3 deletions

```diff
@@ -14,7 +14,7 @@
 # limitations under the License.

 import torch
-from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb
+from megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd
 from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
 from transformers import AutoConfig

@@ -47,8 +47,20 @@ def test_rope_embeddings():
         seq_len_interpolation_factor=nemo_config.seq_len_interpolation_factor,
     )
     rotary_pos_emb = rotary_pos_layer(q.shape[1])
-    q_post_nemo = apply_rotary_pos_emb(q.transpose(0, 1).cuda(), rotary_pos_emb.cuda(), config=nemo_config).cpu()
-    k_post_nemo = apply_rotary_pos_emb(k.transpose(0, 1).cuda(), rotary_pos_emb.cuda(), config=nemo_config).cpu()
+    # Note: Use the backend implementation of the RoPE to avoid
+    # getting or instantiating a CP process group.
+    q_post_nemo = _apply_rotary_pos_emb_bshd(
+        q.transpose(0, 1).cuda(),
+        rotary_pos_emb.cuda(),
+        rotary_interleaved=nemo_config.rotary_interleaved,
+        multi_latent_attention=nemo_config.multi_latent_attention,
+    ).cpu()
+    k_post_nemo = _apply_rotary_pos_emb_bshd(
+        k.transpose(0, 1).cuda(),
+        rotary_pos_emb.cuda(),
+        rotary_interleaved=nemo_config.rotary_interleaved,
+        multi_latent_attention=nemo_config.multi_latent_attention,
+    ).cpu()

     torch.testing.assert_close(q_post, q_post_nemo.transpose(0, 1))
     torch.testing.assert_close(k_post, k_post_nemo.transpose(0, 1))
```
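For intuition, the rotation that `_apply_rotary_pos_emb_bshd` performs boils down to an elementwise 2D rotation of channel pairs by a position-dependent angle. A minimal pure-Python sketch of that core operation (the helper name here is illustrative, not part of Megatron's API):

```python
import math


def rope_rotate_pair(x0: float, x1: float, theta: float) -> tuple:
    """Rotate one (x0, x1) channel pair by angle theta.

    This is the per-pair rotation at the heart of rotary position
    embeddings in the [seq, batch, head, dim] ("bshd") layout.
    """
    c, s = math.cos(theta), math.sin(theta)
    # Standard 2D rotation matrix applied to the pair.
    return (x0 * c - x1 * s, x0 * s + x1 * c)
```

Rotating the unit vector (1, 0) by 90 degrees yields (approximately) (0, 1), which is a quick sanity check on the convention.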

sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_train_esm2.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -327,7 +327,7 @@ def test_main_runs(tmp_path, dummy_protein_dataset, dummy_parquet_train_val_inpu
     event_files = list(log_dir.rglob("events.out.tfevents*"))
     assert event_files, f"No TensorBoard event files found under {log_dir}"
     assert "val_ppl" in trainer.logged_metrics  # validation logging on by default
-    assert "tflops_per_sec_per_gpu" in trainer.logged_metrics  # ensuring that tflops logger can be added
+    assert "TFLOPS_per_GPU" in trainer.logged_metrics  # ensuring that tflops logger can be added
     assert "train_step_timing in s" in trainer.logged_metrics
```
sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -157,7 +157,9 @@ def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
             return forward_out
         # Reminder: the model's predictions for input i land at output i+1. To get everything to align, we prepend the
         # EOS token to the input sequences and take the outputs for all but the first token.
-        forward_out_tp_gathered = _gather_along_last_dim(forward_out)
+        forward_out_tp_gathered = _gather_along_last_dim(
+            forward_out, group=parallel_state.get_tensor_model_parallel_group()
+        )
         # else:
         #     forward_out_tp_gathered = _collect_into_dim(forward_out, dim=-1)
         forward_out_gathered = _gather_along_cp_dim(forward_out_tp_gathered)
```
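Conceptually, `_gather_along_last_dim` concatenates each tensor-parallel rank's output shard along the last dimension (e.g. vocab-parallel logits). A single-process, list-based sketch of that behavior, assuming each shard is a list of rows (the function below is an illustration, not the Megatron implementation, which uses `torch.distributed` collectives):

```python
def gather_along_last_dim(shards):
    """Concatenate per-rank output shards along the last axis.

    `shards` is a list with one entry per (hypothetical) TP rank; each
    entry is a list of rows. The result has the same rows, with each
    row formed by concatenating that row's slice from every rank.
    """
    n_rows = len(shards[0])
    # For each row index, join that row's pieces from all ranks in rank order.
    return [sum((shard[r] for shard in shards), []) for r in range(n_rows)]
```

With two ranks each holding half of a row's columns, the gather reassembles the full row in rank order.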

sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_train.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -146,7 +146,7 @@ def test_train_evo2_stops(tmp_path):
     )

     assert "reduced_train_loss" in trainer.logged_metrics  # validation logging on by default
-    assert "tflops_per_sec_per_gpu" in trainer.logged_metrics  # ensuring that tflops logger can be added
+    assert "TFLOPS_per_GPU" in trainer.logged_metrics  # ensuring that tflops logger can be added
     assert "train_step_timing in s" in trainer.logged_metrics
```

sub-packages/bionemo-evo2/tests/bionemo/evo2/test_hyena_operators.py

Lines changed: 15 additions & 17 deletions

```diff
@@ -68,14 +68,14 @@ def test_gpu_forward(self, operator: ParallelHyenaOperator):
         g = operator.num_groups
         dg = operator.group_dim

-        x1 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        x2 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        v = torch.ones((batch_size, seq_len, g, dg), device=device)
+        x1 = torch.ones((batch_size, (g * dg), seq_len), device=device)
+        x2 = torch.ones((batch_size, (g * dg), seq_len), device=device)
+        v = torch.ones((batch_size, (g * dg), seq_len), device=device)

         output = operator(x1, x2, v)
         assert output.shape[0] == batch_size
-        assert output.shape[1] == seq_len
-        assert output.shape[2] == operator.hidden_size
+        assert output.shape[1] == operator.hidden_size
+        assert output.shape[2] == seq_len


 class TestParallelShortHyenaOperator:
@@ -89,7 +89,6 @@ def operator(self, transformer_config: TransformerConfig, hyena_config: HyenaCon
             init_method="small_init",
             short_conv_class=ParallelCausalDepthwiseConv1d,
             use_fast_causal_conv=False,
-            is_mlp=False,
             local_init=False,
             use_conv_bias=False,
         )
@@ -109,14 +108,14 @@ def test_gpu_forward(self, operator: ParallelShortHyenaOperator):
         g = operator.num_groups
         dg = operator.group_dim

-        x1 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        x2 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        v = torch.ones((batch_size, seq_len, g, dg), device=device)
+        x1 = torch.ones((batch_size, (g * dg), seq_len), device=device)
+        x2 = torch.ones((batch_size, (g * dg), seq_len), device=device)
+        v = torch.ones((batch_size, (g * dg), seq_len), device=device)

         output = operator(x1, x2, v)
         assert output.shape[0] == batch_size
-        assert output.shape[1] == seq_len
-        assert output.shape[2] == operator.hidden_size
+        assert output.shape[1] == operator.hidden_size
+        assert output.shape[2] == seq_len


 class TestParallelShortHyenaOperatorWithConvBias:
@@ -130,7 +129,6 @@ def operator(self, transformer_config: TransformerConfig, hyena_config: HyenaCon
             init_method="small_init",
             short_conv_class=ParallelCausalDepthwiseConv1d,
             use_fast_causal_conv=False,
-            is_mlp=False,
             local_init=False,
             use_conv_bias=True,
         )
@@ -150,14 +148,14 @@ def test_gpu_forward(self, operator: ParallelShortHyenaOperator):
         g = operator.num_groups
         dg = operator.group_dim

-        x1 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        x2 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        v = torch.ones((batch_size, seq_len, g, dg), device=device)
+        x1 = torch.ones((batch_size, (g * dg), seq_len), device=device)
+        x2 = torch.ones((batch_size, (g * dg), seq_len), device=device)
+        v = torch.ones((batch_size, (g * dg), seq_len), device=device)

         output = operator(x1, x2, v)
         assert output.shape[0] == batch_size
-        assert output.shape[1] == seq_len
-        assert output.shape[2] == operator.hidden_size
+        assert output.shape[1] == operator.hidden_size
+        assert output.shape[2] == seq_len


 class TestParallelCausalDepthwiseConv1d:
```
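The shape change in these tests is the core of the Hyena standardization: operator inputs and outputs move from sequence-first `(B, L, G, DG)` / `(B, L, hidden)` to channels-first `(B, G*DG, L)` / `(B, hidden, L)`, which avoids repeated rearranges around the convolution kernels. A hypothetical helper summarizing the new convention, assuming `hidden_size == num_groups * group_dim` as these tests do:

```python
def hyena_operator_io_shapes(batch_size, seq_len, num_groups, group_dim):
    """Illustrative helper (not part of the codebase): return the
    channels-first input and output shapes the updated Hyena operators
    standardize on, assuming hidden == num_groups * group_dim.
    """
    hidden = num_groups * group_dim
    # Inputs x1, x2, v and the output now all share (B, channels, L) layout.
    in_shape = (batch_size, hidden, seq_len)
    out_shape = (batch_size, hidden, seq_len)
    return in_shape, out_shape
```

For example, a batch of 2 sequences of length 16 with 4 groups of dimension 8 gives `(2, 32, 16)` for both input and output, matching the assertions in the updated tests.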

sub-packages/bionemo-geneformer/src/bionemo/geneformer/model/finetune_token_regressor.py

Lines changed: 2 additions & 15 deletions

```diff
@@ -24,10 +24,7 @@
 from nemo.collections.llm.peft.lora import LoRA, LoRALinear
 from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ParallelLinearAdapter
 from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
-from nemo.lightning.megatron_parallel import (
-    masked_token_loss,
-    masked_token_loss_context_parallel,
-)
+from nemo.lightning.megatron_parallel import masked_token_loss
 from torch import Tensor, nn

 from bionemo.llm.model.biobert.model import BioBertConfig, BioBertOutput, MegatronBioBertModel
@@ -102,17 +99,7 @@ def forward(
         # TODO(@jstjohn) also handle different output keys, like the sequence loss.

         cp_size = parallel_state.get_context_parallel_world_size()
-        if cp_size == 1:
-            # reduce the loss across the micro batch
-            loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
-        else:
-            # reduce the loss across the micro batch.
-            # TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
-            # This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
-            # other necessary keys to the batch. Thanks!
-            loss_for_microbatch = masked_token_loss_context_parallel(
-                unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
-            )
+        loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"], cp_size)

         # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
         # reducing the loss across the data parallel group.
```
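Both this file and bionemo-llm's `loss.py` now call a single `masked_token_loss(unreduced_token_loss, loss_mask, cp_size)` instead of branching on the context-parallel size. A single-process, scalar sketch of the assumed contract (the real NeMo implementation operates on tensors and additionally SUM-reduces across the context-parallel group when `cp_size > 1`, which is omitted here):

```python
def masked_token_loss_sketch(unreduced_token_loss, loss_mask, cp_size=1):
    """Sketch of the unified loss entry point assumed by this PR.

    Sums per-token losses over non-masked positions. No mean
    normalization is performed; with cp_size > 1 the real function
    also reduces the sum across the context-parallel group.
    """
    return sum(loss * mask for loss, mask in zip(unreduced_token_loss, loss_mask))
```

With losses `[1.0, 2.0, 3.0]` and mask `[1, 0, 1]`, only the first and third tokens contribute, giving a sum of `4.0`.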

sub-packages/bionemo-llm/pyproject.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -16,7 +16,7 @@ dependencies = [
     # external
     'lightning>=2.2.1',
     'megatron-core',
-    'nemo_toolkit[nlp]>=2.2.1',
+    'nemo_toolkit[nlp,eval]>=2.2.1',
     'nemo-run',
     'hatchling',
 ]
```

sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py

Lines changed: 21 additions & 19 deletions

```diff
@@ -22,7 +22,6 @@
 from nemo.lightning.megatron_parallel import (
     MegatronLossReduction,
     masked_token_loss,
-    masked_token_loss_context_parallel,
 )
 from torch import Tensor

@@ -179,24 +178,17 @@ def forward(

         # TODO(@jstjohn) also handle different output keys, like the sequence loss.

-        # compute loss
+        # Compute loss over "valid" tokens in the microbatch, i.e. the non-masked tokens.
+        # The loss is not normalized, only potentially reduced via torch.distributed.ReduceOp.SUM
+        # across the context parallel process group, so you need to divide by the number
+        # of non-masked tokens (loss_mask.sum()) to compute the mean reduced loss per token.
         cp_size = parallel_state.get_context_parallel_world_size()
-        if cp_size == 1:
-            # reduce the loss across the micro batch per valid token
-            loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
-        else:
-            # reduce the loss across the micro batch per valid token.
-            # TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
-            # This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
-            # other necessary keys to the batch. Thanks!
-            loss_for_microbatch = masked_token_loss_context_parallel(
-                unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
-            )
+        loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"], cp_size=cp_size)
+        num_valid_tokens_in_microbatch = batch["loss_mask"].sum()

         # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
         # reducing the loss across the data parallel group.
         if self.validation_step and not self.val_drop_last:
-            num_valid_tokens_in_microbatch = batch["loss_mask"].sum()
             if loss_for_microbatch.isnan():
                 # TODO(@jomitchell): Add a unit test for this. This is the case where there are no valid tokens in the microbatch for the loss
                 # to be computed over, so we expect a NaN loss (divide by zero for a mean) but we make this an expected and non-breaking case,
@@ -205,9 +197,8 @@ def forward(
                     raise ValueError("Got NaN loss with non-empty input")
                 loss_sum_for_microbatch = torch.zeros_like(num_valid_tokens_in_microbatch)
             else:
-                loss_sum_for_microbatch = (
-                    num_valid_tokens_in_microbatch * loss_for_microbatch
-                )  # sum over all valid tokens
+                # The reduced loss is already the sum of all losses from masked_token_loss().
+                loss_sum_for_microbatch = loss_for_microbatch

             # In this case we need to store the loss sum as well as the number of valid tokens in the microbatch.
             loss_sum_and_microbatch_size_all_gpu = torch.cat(
@@ -216,17 +207,28 @@ def forward(
                     Tensor([num_valid_tokens_in_microbatch]).cuda().clone().detach(),
                 ]
             )
+
+            # Reduce the loss sum across the data parallel group to get the total loss
+            # for all data parallel / distributed microbatches.
             torch.distributed.all_reduce(
                 loss_sum_and_microbatch_size_all_gpu,
                 group=parallel_state.get_data_parallel_group(),
                 op=torch.distributed.ReduceOp.SUM,
             )
+
+            # Return the loss tensor multiplied by the context parallel size,
+            # and the data & context parallel reduced loss sum.
             return loss_for_microbatch * cp_size, {
                 "loss_sum_and_microbatch_size": loss_sum_and_microbatch_size_all_gpu
             }

-        # average the losses across the data parallel group, but also return the unreduced loss
-        reduced_loss = average_losses_across_data_parallel_group([loss_for_microbatch])
+        # Return the loss tensor multiplied by the context parallel size, as well as
+        # the data-parallel averaged loss, i.e. the loss divided by the DP size.
+        # Normalize the loss by the number of "valid" tokens, because masked_token_loss
+        # no longer does this normalization, and BioNeMo losses expect this normalization.
+        reduced_loss = (
+            average_losses_across_data_parallel_group([loss_for_microbatch]) / num_valid_tokens_in_microbatch
+        )
         return loss_for_microbatch * cp_size, {"avg": reduced_loss}
```
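Because `masked_token_loss` now returns an unnormalized sum, the caller recovers the mean loss per valid token by dividing by `loss_mask.sum()`, as the new comments in the diff describe. A scalar, single-process sketch of that normalization (illustrative names, not the BioNeMo API):

```python
def mean_loss_per_valid_token(unreduced_token_loss, loss_mask):
    """Divide the masked loss sum by the number of non-masked tokens.

    Mirrors the normalization the updated loss.py applies after
    masked_token_loss(), which no longer normalizes internally.
    """
    loss_sum = sum(loss * mask for loss, mask in zip(unreduced_token_loss, loss_mask))
    num_valid = sum(loss_mask)
    # Division by zero here corresponds to the all-masked microbatch
    # case the real code treats as an expected NaN.
    return loss_sum / num_valid
```

With losses `[1.0, 2.0, 3.0]` and mask `[1, 0, 1]`, the sum over valid tokens is `4.0` and the mean per valid token is `2.0`.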
