 # See the License for the specific language governing permissions and
 # limitations under the License.
 import random
-from typing import Dict, List, Optional
+from typing import Dict, Iterator, List, Optional, Tuple
 
 import torch
+from torch.utils.data.dataset import IterableDataset
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
 
 from .hstu_batch import FeatureConfig, HSTUBatch
 
 
-class RandomInferenceDataGenerator:
+class RandomInferenceDataset(
+    IterableDataset[Tuple[HSTUBatch, torch.Tensor, torch.Tensor]]
+):
     """
-    A random generator for the inference batches
+    An iterable dataset that generates random inference batches.
 
@@ -32,12 +33,12 @@ class RandomInferenceDataGenerator:
         action_feature_name (str): The action feature name.
         max_num_users (int): The maximum number of users.
         max_batch_size (int): The maximum batch size.
-        max_seqlen (int): The maximum sequence length (with candidates) for item
-                          in request per user. The length of action sequence in
-                          request the same with that of HISTORY item sequence.
+        max_history_length (int): The maximum item history length in a request,
+                                  per user. The action sequence in a request has
+                                  the same length.
         max_num_candidates (int): The maximum number of candidates.
         max_incremental_seqlen (int): The maximum incremental length of the
                                       HISTORY item AND action sequences.
+        max_num_cached_batches (int, optional): The number of batches to
+                                                generate and cache. Defaults to 1.
         full_mode (bool): If True, generate full batches (maximum batch size
                           and number of candidates).
     """
 
@@ -49,9 +50,10 @@ def __init__(
         action_feature_name: str = "",
         max_num_users: int = 1,
         max_batch_size: int = 32,
-        max_seqlen: int = 4096,
+        max_history_length: int = 4096,
         max_num_candidates: int = 200,
         max_incremental_seqlen: int = 64,
+        max_num_cached_batches: int = 1,
         full_mode: bool = False,
     ):
         super().__init__()
@@ -72,125 +74,88 @@ def __init__(
 
         self._max_num_users = min(max_num_users, 2**16)
         self._max_batch_size = max_batch_size
-        self._max_hist_len = max_seqlen - max_num_candidates
-        self._max_incr_fea_len = max(max_incremental_seqlen, 1)
+        self._max_hist_len = max_history_length
         self._max_num_candidates = max_num_candidates
+        self._max_incr_fea_len = max(max_incremental_seqlen, 1)
+        self._num_generated_batches = max(max_num_cached_batches, 1)
 
         self._full_mode = full_mode
 
         self._item_history: Dict[int, torch.Tensor] = dict()
         self._action_history: Dict[int, torch.Tensor] = dict()
 
-    def get_inference_batch_user_ids(self) -> Optional[torch.Tensor]:
-        if self._full_mode:
-            batch_size = self._max_batch_size
-            user_ids = list(range(self._max_batch_size))
-        else:
-            batch_size = random.randint(1, self._max_batch_size)
-            user_ids = torch.randint(self._max_num_users, (batch_size,)).tolist()
-            user_ids = list(set(user_ids))
-
-        user_ids = torch.tensor(
-            [
-                uid
-                for uid in user_ids
-                if uid not in self._item_history
-                or len(self._item_history[uid]) < self._max_hist_len
-            ]
-        ).long()
-        if self._full_mode and len(user_ids) == 0:
-            batch_size = self._max_batch_size
-            user_ids = list(
-                range(
-                    self._max_batch_size,
-                    min(self._max_batch_size * 2, self._max_num_users),
+        num_cached_batches = 0
+        self._cached_batch: List[Tuple[HSTUBatch, torch.Tensor, torch.Tensor]] = list()
+        for seqlen_idx in range(
+            self._max_incr_fea_len, self._max_hist_len, self._max_incr_fea_len
+        ):
+            for idx in range(0, self._max_num_users, self._max_batch_size):
+                if self._full_mode:
+                    user_ids = list(
+                        range(idx, min(self._max_num_users, idx + self._max_batch_size))
+                    )
+                else:
+                    # Draw a random batch of users; duplicates are dropped below.
+                    batch_size = random.randint(1, self._max_batch_size)
+                    user_ids = torch.randint(self._max_num_users, (batch_size,)).tolist()
+                    user_ids = list(set(user_ids))
+
+                batch_size = len(user_ids)
+
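+                # Build per-user sequences. Each user's full random item/action
+                # history is seeded once and reused, so later sweep steps slice
+                # progressively longer prefixes of the same history.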
+                item_seq = list()
+                action_seq = list()
+                for uid in user_ids:
+                    if uid not in self._item_history or uid not in self._action_history:
+                        self._item_history[uid] = torch.randint(
+                            self._max_item_id + 1,
+                            (self._max_hist_len + self._max_num_candidates,),
+                        )
+                        self._action_history[uid] = torch.randint(
+                            self._max_action_id + 1,
+                            (self._max_hist_len + self._max_num_candidates,),
+                        )
+
+                    item_seq.append(
+                        self._item_history[uid][: seqlen_idx + self._max_num_candidates]
+                    )
+                    action_seq.append(self._action_history[uid][:seqlen_idx])
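+                # Pack the per-user variable-length sequences into a single
+                # KeyedJaggedTensor keyed by feature name.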
+                features = KeyedJaggedTensor.from_jt_dict(
+                    {
+                        self._item_fea_name: JaggedTensor.from_dense(item_seq),
+                        self._action_fea_name: JaggedTensor.from_dense(action_seq),
+                    }
                 )
-        )
-        user_ids = torch.tensor(user_ids).long()
-        return user_ids if len(user_ids) > 0 else None
-
-    def get_random_inference_batch(
-        self, user_ids, truncate_start_positions
-    ) -> Optional[HSTUBatch]:
-        batch_size = len(user_ids)
-        if batch_size == 0:
-            return None
-        user_ids = user_ids.tolist()
-        item_hists = [
-            self._item_history[uid] if uid in self._item_history else torch.tensor([])
-            for uid in user_ids
-        ]
-        action_hists = [
-            self._action_history[uid]
-            if uid in self._action_history
-            else torch.tensor([])
-            for uid in user_ids
-        ]
-
-        lengths = torch.tensor([len(hist_seq) for hist_seq in item_hists]).long()
-        incr_lengths = torch.randint(
-            low=1, high=self._max_incr_fea_len + 1, size=(batch_size,)
-        )
-        new_lengths = torch.clamp(lengths + incr_lengths, max=self._max_hist_len).long()
-        incr_lengths = new_lengths - lengths
-
-        num_candidates = torch.randint(
-            low=1, high=self._max_num_candidates + 1, size=(batch_size,)
-        )
-        if self._full_mode:
-            incr_lengths = torch.full((batch_size,), self._max_incr_fea_len)
-            new_lengths = torch.clamp(
-                lengths + incr_lengths, max=self._max_hist_len
-            ).long()
-            incr_lengths = new_lengths - lengths
-            num_candidates = torch.full((batch_size,), self._max_num_candidates)
-
-        # Caveats: truncate_start_positions is for interleaved item-action sequence
-        item_start_positions = (truncate_start_positions / 2).to(torch.int32)
-        action_start_positions = (truncate_start_positions / 2).to(torch.int32)
-
-        item_seq = list()
-        action_seq = list()
-        for idx, uid in enumerate(user_ids):
-            self._item_history[uid] = torch.cat(
-                [
-                    item_hists[idx],
-                    torch.randint(self._max_item_id + 1, (incr_lengths[idx],)),
-                ],
-                dim=0,
-            ).long()
-            self._action_history[uid] = torch.cat(
-                [
-                    action_hists[idx],
-                    torch.randint(self._max_action_id + 1, (incr_lengths[idx],)),
-                ],
-                dim=0,
-            ).long()
-
-            item_history = torch.cat(
-                [
-                    self._item_history[uid][item_start_positions[idx] :],
-                    torch.randint(self._max_item_id + 1, (num_candidates[idx].item(),)),
-                ],
-                dim=0,
-            )
-            item_seq.append(item_history)
-            action_seq.append(self._action_history[uid][action_start_positions[idx] :])
-
-        features = KeyedJaggedTensor.from_jt_dict(
-            {
-                self._item_fea_name: JaggedTensor.from_dense(item_seq),
-                self._action_fea_name: JaggedTensor.from_dense(action_seq),
-            }
-        )
-
-        return HSTUBatch(
-            features=features,
-            batch_size=batch_size,
-            feature_to_max_seqlen=self._fea_name_to_max_seqlen,
-            contextual_feature_names=self._contextual_fea_names,
-            item_feature_name=self._item_fea_name,
-            action_feature_name=self._action_fea_name,
-            max_num_candidates=self._max_num_candidates,
-            num_candidates=num_candidates,
-        )
+
+                if self._full_mode:
+                    num_candidates = torch.full((batch_size,), self._max_num_candidates)
+                else:
+                    num_candidates = torch.randint(
+                        low=1, high=self._max_num_candidates + 1, size=(batch_size,)
+                    )
+
+                # The total length counts both item and action events, which are
+                # interleaved downstream, so it is twice the item history length.
+                total_history_lengths = torch.full((batch_size,), seqlen_idx * 2)
+
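+                # Assemble the batch and move it to the current CUDA device up
+                # front, so iteration serves cached, device-resident batches.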
+                batch = HSTUBatch(
+                    features=features,
+                    batch_size=batch_size,
+                    feature_to_max_seqlen=self._fea_name_to_max_seqlen,
+                    contextual_feature_names=self._contextual_fea_names,
+                    item_feature_name=self._item_fea_name,
+                    action_feature_name=self._action_fea_name,
+                    max_num_candidates=self._max_num_candidates,
+                    num_candidates=num_candidates,
+                ).to(device=torch.cuda.current_device())
+                self._cached_batch.append(
+                    (batch, torch.tensor(user_ids).long(), total_history_lengths)
+                )
+                num_cached_batches += 1
+                if num_cached_batches >= self._num_generated_batches:
+                    break
+            else:  # inner loop exhausted without reaching the cap
+                continue
+            break  # cap reached: stop the outer sweep as well
+
+        self._num_generated_batches = len(self._cached_batch)
+        self._max_num_batches = self._num_generated_batches
+        self._iloc = 0
+
+    def __iter__(self) -> Iterator[Tuple[HSTUBatch, torch.Tensor, torch.Tensor]]:
+        """
+        Iterate over the cached batches, cycling through them across epochs.
+
+        Yields:
+            Tuple[HSTUBatch, torch.Tensor, torch.Tensor]: The next
+                (batch, user_ids, history_lens) tuple in the cycle.
+        """
+        for _ in range(len(self)):
+            yield self._cached_batch[self._iloc]
+            self._iloc = (self._iloc + 1) % self._num_generated_batches
+
+    def __len__(self) -> int:
+        """
+        Get the number of cached batches.
+
+        Returns:
+            int: The number of batches.
+        """
+        return self._max_num_batches
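
A minimal usage sketch of the new dataset (hypothetical: `feature_configs` and the model call are placeholders for parameters not shown in this diff, and a CUDA device is assumed because cached batches are moved to `torch.cuda.current_device()` at construction time):

```python
from torch.utils.data import DataLoader

# Hypothetical construction; arguments hidden by the hunk above (e.g. the
# FeatureConfig list) are placeholders, and other defaults are kept.
dataset = RandomInferenceDataset(
    feature_configs=feature_configs,  # placeholder: per-feature FeatureConfig list
    item_feature_name="item_id",
    action_feature_name="action",
    max_num_users=128,
    max_batch_size=32,
    max_history_length=1024,
    max_num_candidates=200,
    max_incremental_seqlen=64,
    max_num_cached_batches=8,
)

# batch_size=None: each dataset element is already a fully formed HSTUBatch.
loader = DataLoader(dataset, batch_size=None)
for batch, user_ids, history_lens in loader:
    outputs = model.predict(batch)  # placeholder model call
```

Because `__iter__` advances a persistent cursor (`self._iloc`), repeated passes keep cycling through the cached batches instead of restarting from the first one.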