
Commit ba47c13

EddyLXJ authored and facebook-github-bot committed
Skip load metadata tensor (#4830)
Summary:
X-link: meta-pytorch/torchrec#3359
Pull Request resolved: #4830
X-link: facebookresearch/FBGEMM#1856

The metadata tensor was recently added for the KV ZCH table, so some older checkpoints do not contain its FQN, and loading such a checkpoint directly fails with a missing-FQN error. This diff skips initializing the metadata tensor in the checkpoint-loading path. The metadata tensor is not used in training, so skipping the load is safe; the tensor is created when the checkpoint is saved.

Reviewed By: steven1327, emlin

Differential Revision: D81811024

fbshipit-source-id: edd731b40c6a843b338cc0c9a7f4ffb55000b706
1 parent 9f81399 commit ba47c13
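The failure mode being avoided is the usual missing-key error raised when a model gains a new tensor that older checkpoints predate. A minimal, self-contained sketch of that general pattern in plain PyTorch (the helper `load_old_checkpoint` and its `new_keys` allowlist are illustrative, not FBGEMM API):

```python
import torch

def load_old_checkpoint(
    model: torch.nn.Module, checkpoint: dict, new_keys: tuple = ()
) -> None:
    # Hypothetical helper, not FBGEMM API. Keys in `new_keys` are tensors
    # added after the checkpoint was written, so they are allowed to be
    # absent; strict=False lets load_state_dict tolerate missing keys.
    result = model.load_state_dict(checkpoint, strict=False)
    # Anything missing beyond the known-new keys is a real error.
    unexpected = set(result.missing_keys) - set(new_keys)
    assert not unexpected, f"checkpoint is missing required keys: {unexpected}"

# Usage: an "old" checkpoint that predates the bias parameter.
model = torch.nn.Linear(4, 4)
old_ckpt = {k: v for k, v in model.state_dict().items() if k != "bias"}
load_old_checkpoint(model, old_ckpt, new_keys=("bias",))
```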

File tree

2 files changed, +23 -7 lines


fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 6 additions & 6 deletions
@@ -3065,6 +3065,7 @@ def split_embedding_weights(
         bucket_sorted_id_splits = [] if self.kv_zch_params else None
         active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
         metadata_splits = [] if self.kv_zch_params else None
+        skip_metadata = False

         table_offset = 0
         for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
@@ -3132,18 +3133,17 @@ def split_embedding_weights(
                     device=torch.device("cpu"),
                     dtype=torch.int64,
                 )
-                metadata_tensor = torch.zeros(
-                    (self.local_weight_counts[i], 1),
-                    device=torch.device("cpu"),
-                    dtype=torch.int64,
-                )
+                skip_metadata = True

                 # self.local_weight_counts[i] = 0 # Reset the count

             # pyre-ignore [16] bucket_sorted_id_splits is not None
             bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
             active_id_cnt_per_bucket_split.append(bucket_t)
-            metadata_splits.append(metadata_tensor)
+            if skip_metadata:
+                metadata_splits = None
+            else:
+                metadata_splits.append(metadata_tensor)

             # for KV ZCH tbe, the sorted_indices is global id for checkpointing and publishing
             # but in backend, local id is used during training, so the KVTensorWrapper need to convert global id to local id
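Distilled, the control flow this hunk introduces: while building per-table splits, hitting the checkpoint-load path for any table flips `skip_metadata`, which collapses `metadata_splits` to `None` instead of appending a placeholder tensor. A standalone restatement under those assumptions (`build_metadata_splits` and its arguments are simplifications, not the real method):

```python
from typing import List, Optional
import torch

def build_metadata_splits(
    local_weight_counts: List[int], loading_checkpoint: bool
) -> Optional[List[torch.Tensor]]:
    # Simplified restatement of the diff's control flow, not the real method.
    metadata_splits: List[torch.Tensor] = []
    for count in local_weight_counts:
        if loading_checkpoint:
            # Old checkpoints may lack the metadata FQN; do not materialize
            # a placeholder tensor (mirrors the deleted torch.zeros block).
            return None
        metadata_splits.append(
            torch.zeros((count, 1), device=torch.device("cpu"), dtype=torch.int64)
        )
    return metadata_splits

# Loading an old checkpoint yields None; a fresh split yields per-table tensors.
assert build_metadata_splits([3, 5], loading_checkpoint=True) is None
assert len(build_metadata_splits([3, 5], loading_checkpoint=False)) == 2
```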

fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

Lines changed: 17 additions & 1 deletion
@@ -811,6 +811,16 @@ def test_apply_kv_state_dict(

         # pyre-fixme[16]: Undefined attribute: `Optional` has no attribute `__iter__`.
         emb2.local_weight_counts = [ids.numel() for ids in bucket_asc_ids_list]
+
+        (
+            _,
+            _,
+            _,
+            metadata_list,
+        ) = emb2.split_embedding_weights(no_snapshot=False, should_flush=True)
+
+        self.assertTrue(metadata_list is None)
+
         emb2.enable_load_state_dict_mode()
         self.assertIsNotNone(emb2._cached_kvzch_data)
         for i, _ in enumerate(emb.embedding_specs):
@@ -844,12 +854,14 @@ def test_apply_kv_state_dict(
             emb_state_dict_list2,
             bucket_asc_ids_list2,
             num_active_id_per_bucket_list2,
-            _,
+            metadata_list2,
         ) = emb2.split_embedding_weights(no_snapshot=False, should_flush=True)
         split_optimizer_states2 = emb2.split_optimizer_states(
             bucket_asc_ids_list2, no_snapshot=False, should_flush=True
         )

+        self.assertTrue(metadata_list2 is not None)
+
         for t in range(len(emb.embedding_specs)):
             sorted_ids = torch.sort(bucket_asc_ids_list[t].flatten())
             sorted_ids2 = torch.sort(bucket_asc_ids_list2[t].flatten())
@@ -881,6 +893,10 @@ def test_apply_kv_state_dict(
                 rtol=tolerance,
             )

+            self.assertTrue(
+                metadata_list2[t].size(0) == bucket_asc_ids_list2[t].size(0)
+            )
+
     def _check_raw_embedding_stream_call_counts(
         self,
         mock_raw_embedding_stream: unittest.mock.Mock,
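The new assertions pin down the contract from both ends: before loading, `split_embedding_weights` returns `metadata_list is None`; after the save/load round trip, each table's metadata split carries one row per bucket-ascending ID. A standalone illustration of that shape invariant with made-up sizes (the tensors are stand-ins, not values from the test):

```python
import torch

# Stand-ins with a made-up size; in the test these come from
# split_embedding_weights after the save/load round trip.
bucket_asc_ids = torch.arange(128, dtype=torch.int64).unsqueeze(1)  # (128, 1)
metadata = torch.zeros((bucket_asc_ids.size(0), 1), dtype=torch.int64)

# The invariant asserted per table: one metadata row per active ID.
assert metadata.size(0) == bucket_asc_ids.size(0)
```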
