1616
1717"""Utilities for generating text."""
1818
19- import pickle
19+ import json
2020import sys
2121from collections .abc import Iterable
2222from typing import TYPE_CHECKING , List , Optional , Tuple , Union
@@ -129,7 +129,7 @@ def send_generate_info(
129129
130130 # send end strings
131131 string_tensor = torch .as_tensor (
132- np .frombuffer (pickle .dumps (end_strings ), dtype = np .int8 ), device = torch .cuda .current_device ()
132+ np .frombuffer (json .dumps (end_strings ). encode ( 'utf-8' ), dtype = np .int8 ), device = torch .cuda .current_device ()
133133 )
134134 size = torch .as_tensor ([string_tensor .size (0 )], device = torch .cuda .current_device (), dtype = torch .int64 )
135135 torch .distributed .broadcast (size , src , model_parallel_group )
@@ -140,7 +140,8 @@ def send_generate_info(
140140
141141 if context_start_idx is not None :
142142 context_idx_tensor = torch .as_tensor (
143- np .frombuffer (pickle .dumps (context_start_idx ), dtype = np .int8 ), device = torch .cuda .current_device ()
143+ np .frombuffer (json .dumps (context_start_idx ).encode ('utf-8' ), dtype = np .int8 ),
144+ device = torch .cuda .current_device (),
144145 )
145146 ctx_size = torch .as_tensor ([context_idx_tensor .size (0 )], device = torch .cuda .current_device (), dtype = torch .int64 )
146147 torch .distributed .broadcast (ctx_size , src , model_parallel_group )
@@ -186,7 +187,7 @@ def receive_generate_info(has_multi_audios=False):
186187 string_tensor = torch .empty (array_size [0 ], dtype = torch .int8 , device = torch .cuda .current_device ())
187188 torch .distributed .broadcast (string_tensor , src , model_parallel_group )
188189 bytes = string_tensor .cpu ().numpy ().tobytes ()
189- end_strings = pickle .loads (bytes )
190+ end_strings = json .loads (bytes . decode ( 'utf-8' ) )
190191
191192 num_audios = None
192193 context_start_idx = None
@@ -199,7 +200,7 @@ def receive_generate_info(has_multi_audios=False):
199200 context_idx_tensor = torch .empty (array_size [0 ], dtype = torch .int8 , device = torch .cuda .current_device ())
200201 torch .distributed .broadcast (context_idx_tensor , src , model_parallel_group )
201202 bytes = context_idx_tensor .cpu ().numpy ().tobytes ()
202- context_start_idx = pickle .loads (bytes )
203+ context_start_idx = json .loads (bytes . decode ( 'utf-8' ) )
203204
204205 return (
205206 context_length_tensor ,
0 commit comments