
Commit aecb5d3

replace pickle.loads with json.loads (#15232)
Signed-off-by: stevehuang52 <[email protected]>
1 parent 412ab81 commit aecb5d3

File tree

2 files changed: +12, -10 lines

nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_utils.py

Lines changed: 6 additions & 5 deletions

@@ -16,7 +16,7 @@
 
 """Utilities for generating text."""
 
-import pickle
+import json
 import sys
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
@@ -129,7 +129,7 @@ def send_generate_info(
 
     # send end strings
     string_tensor = torch.as_tensor(
-        np.frombuffer(pickle.dumps(end_strings), dtype=np.int8), device=torch.cuda.current_device()
+        np.frombuffer(json.dumps(end_strings).encode('utf-8'), dtype=np.int8), device=torch.cuda.current_device()
     )
     size = torch.as_tensor([string_tensor.size(0)], device=torch.cuda.current_device(), dtype=torch.int64)
     torch.distributed.broadcast(size, src, model_parallel_group)
@@ -140,7 +140,8 @@ def send_generate_info(
 
     if context_start_idx is not None:
         context_idx_tensor = torch.as_tensor(
-            np.frombuffer(pickle.dumps(context_start_idx), dtype=np.int8), device=torch.cuda.current_device()
+            np.frombuffer(json.dumps(context_start_idx).encode('utf-8'), dtype=np.int8),
+            device=torch.cuda.current_device(),
         )
         ctx_size = torch.as_tensor([context_idx_tensor.size(0)], device=torch.cuda.current_device(), dtype=torch.int64)
         torch.distributed.broadcast(ctx_size, src, model_parallel_group)
@@ -186,7 +187,7 @@ def receive_generate_info(has_multi_audios=False):
     string_tensor = torch.empty(array_size[0], dtype=torch.int8, device=torch.cuda.current_device())
     torch.distributed.broadcast(string_tensor, src, model_parallel_group)
     bytes = string_tensor.cpu().numpy().tobytes()
-    end_strings = pickle.loads(bytes)
+    end_strings = json.loads(bytes.decode('utf-8'))
 
     num_audios = None
     context_start_idx = None
@@ -199,7 +200,7 @@ def receive_generate_info(has_multi_audios=False):
         context_idx_tensor = torch.empty(array_size[0], dtype=torch.int8, device=torch.cuda.current_device())
         torch.distributed.broadcast(context_idx_tensor, src, model_parallel_group)
         bytes = context_idx_tensor.cpu().numpy().tobytes()
-        context_start_idx = pickle.loads(bytes)
+        context_start_idx = json.loads(bytes.decode('utf-8'))
 
     return (
         context_length_tensor,

nemo/collections/speechlm/utils/text_generation/audio_text_generation_utils.py

Lines changed: 6 additions & 5 deletions

@@ -14,7 +14,7 @@
 
 """Utilities for generating text."""
 
-import pickle
+import json
 from collections.abc import Iterable
 from typing import List, Optional, Tuple, TypedDict, Union
 import numpy as np
@@ -140,7 +140,7 @@ def send_generate_info(
 
     # send end strings
     string_tensor = torch.as_tensor(
-        np.frombuffer(pickle.dumps(end_strings), dtype=np.int8), device=torch.cuda.current_device()
+        np.frombuffer(json.dumps(end_strings).encode('utf-8'), dtype=np.int8), device=torch.cuda.current_device()
     )
     size = torch.as_tensor([string_tensor.size(0)], device=torch.cuda.current_device(), dtype=torch.int64)
     torch.distributed.broadcast(size, src, model_parallel_group)
@@ -151,7 +151,8 @@ def send_generate_info(
 
     if context_start_idx is not None:
         context_idx_tensor = torch.as_tensor(
-            np.frombuffer(pickle.dumps(context_start_idx), dtype=np.int8), device=torch.cuda.current_device()
+            np.frombuffer(json.dumps(context_start_idx).encode('utf-8'), dtype=np.int8),
+            device=torch.cuda.current_device(),
         )
         ctx_size = torch.as_tensor([context_idx_tensor.size(0)], device=torch.cuda.current_device(), dtype=torch.int64)
         torch.distributed.broadcast(ctx_size, src, model_parallel_group)
@@ -197,7 +198,7 @@ def receive_generate_info(has_multi_audios=False):
     string_tensor = torch.empty(array_size[0], dtype=torch.int8, device=torch.cuda.current_device())
     torch.distributed.broadcast(string_tensor, src, model_parallel_group)
     bytes = string_tensor.cpu().numpy().tobytes()
-    end_strings = pickle.loads(bytes)
+    end_strings = json.loads(bytes.decode('utf-8'))
 
     num_audios = None
     context_start_idx = None
@@ -210,7 +211,7 @@ def receive_generate_info(has_multi_audios=False):
         context_idx_tensor = torch.empty(array_size[0], dtype=torch.int8, device=torch.cuda.current_device())
         torch.distributed.broadcast(context_idx_tensor, src, model_parallel_group)
         bytes = context_idx_tensor.cpu().numpy().tobytes()
-        context_start_idx = pickle.loads(bytes)
+        context_start_idx = json.loads(bytes.decode('utf-8'))
 
     return (
         context_length_tensor,
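
Both files make the same change: end_strings and context_start_idx are now serialized with json.dumps instead of pickle.dumps before being viewed as an int8 buffer for the broadcast, and decoded with json.loads on the receiving ranks. Below is a minimal CPU-only sketch of that round-trip; the helper names are illustrative (not NeMo API), and the torch.distributed.broadcast of the size and payload tensors is omitted.

import json

import numpy as np
import torch


def encode_to_int8_tensor(obj):
    # json handles JSON-serializable values (lists, strings, ints), which is
    # all that end_strings and context_start_idx need, unlike arbitrary pickle.
    buf = np.frombuffer(json.dumps(obj).encode('utf-8'), dtype=np.int8)
    return torch.as_tensor(buf)


def decode_from_int8_tensor(tensor):
    # Mirror of the receive path: tensor -> bytes -> UTF-8 string -> object.
    raw = tensor.cpu().numpy().tobytes()
    return json.loads(raw.decode('utf-8'))


end_strings = ["</s>", "<extra_id_1>"]
assert decode_from_int8_tensor(encode_to_int8_tensor(end_strings)) == end_strings

In the actual functions the sender broadcasts the buffer length first, so receivers can allocate a matching torch.empty(size, dtype=torch.int8) tensor before the payload broadcast fills it.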
