 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import random
-from typing import Dict, List, Optional
+from typing import Dict, Iterator, List, Tuple

 import torch
+from torch.utils.data.dataset import IterableDataset
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

 from .hstu_batch import FeatureConfig, HSTUBatch


-class RandomInferenceDataGenerator:
+class RandomInferenceDataset(
+    IterableDataset[Tuple[HSTUBatch, torch.Tensor, torch.Tensor]]
+):
     """
     A random generator for inference batches.

@@ -32,12 +34,12 @@ class RandomInferenceDataGenerator:
         action_feature_name (str): The action feature name.
         max_num_users (int): The maximum number of users.
         max_batch_size (int): The maximum batch size.
-        max_seqlen (int): The maximum sequence length (with candidates) for item
-                          in request per user. The length of action sequence in
-                          request the same with that of HISTORY item sequence.
+        max_history_length (int): The maximum item history length per user request.
+                                  The action sequence in a request has the same length.
         max_num_candidates (int): The maximum number of candidates.
         max_incremental_seqlen (int): The maximum incremental length of the HISTORY
                                       item AND action sequences.
+        max_num_cached_batches (int, optional): The number of batches to pre-generate and cache. Defaults to 1.
         full_mode (bool): Whether to generate full batches (sequential user
                           ids and the maximum number of candidates).
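+
+        A minimal usage sketch (illustrative only; ``feature_configs`` and
+        ``item_feature_name`` are assumed arguments not shown in this diff)::
+
+            dataset = RandomInferenceDataset(
+                feature_configs=feature_configs,
+                item_feature_name="item",
+                action_feature_name="action",
+            )
+            for batch, user_ids, total_history_lengths in dataset:
+                ...  # run inference on `batch` for `user_ids`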
     """

@@ -49,9 +51,10 @@ def __init__(
         action_feature_name: str = "",
         max_num_users: int = 1,
         max_batch_size: int = 32,
-        max_seqlen: int = 4096,
+        max_history_length: int = 4096,
         max_num_candidates: int = 200,
         max_incremental_seqlen: int = 64,
+        max_num_cached_batches: int = 1,
         full_mode: bool = False,
     ):
         super().__init__()
@@ -72,125 +75,104 @@ def __init__(

         self._max_num_users = min(max_num_users, 2**16)
         self._max_batch_size = max_batch_size
-        self._max_hist_len = max_seqlen - max_num_candidates
-        self._max_incr_fea_len = max(max_incremental_seqlen, 1)
+        self._max_hist_len = max_history_length
         self._max_num_candidates = max_num_candidates
+        self._max_incr_fea_len = max(max_incremental_seqlen, 1)
+        self._num_generated_batches = max(max_num_cached_batches, 1)

         self._full_mode = full_mode

         self._item_history: Dict[int, torch.Tensor] = dict()
         self._action_history: Dict[int, torch.Tensor] = dict()

-    def get_inference_batch_user_ids(self) -> Optional[torch.Tensor]:
-        if self._full_mode:
-            batch_size = self._max_batch_size
-            user_ids = list(range(self._max_batch_size))
-        else:
-            batch_size = random.randint(1, self._max_batch_size)
-            user_ids = torch.randint(self._max_num_users, (batch_size,)).tolist()
-            user_ids = list(set(user_ids))
-
-        user_ids = torch.tensor(
-            [
-                uid
-                for uid in user_ids
-                if uid not in self._item_history
-                or len(self._item_history[uid]) < self._max_hist_len
-            ]
-        ).long()
-        if self._full_mode and len(user_ids) == 0:
-            batch_size = self._max_batch_size
-            user_ids = list(
-                range(
-                    self._max_batch_size,
-                    min(self._max_batch_size * 2, self._max_num_users),
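+        # Pre-generate and cache all inference batches up front. The outer loop
+        # sweeps the per-user history length in steps of max_incremental_seqlen;
+        # the inner loop tiles the user id space into batches.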
+        num_cached_batches = 0
+        self._cached_batch = list()
+        for seqlen_idx in range(
+            max_incremental_seqlen, self._max_hist_len, max_incremental_seqlen
+        ):
+            for idx in range(0, self._max_num_users, self._max_batch_size):
+                if self._full_mode:
+                    user_ids = list(
+                        range(idx, min(self._max_num_users, idx + self._max_batch_size))
+                    )
+                else:
+                    user_ids = torch.randint(
+                        self._max_num_users, (self._max_batch_size,)
+                    ).tolist()
+                    user_ids = list(set(user_ids))
+
+                batch_size = len(user_ids)
+
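+                # Build per-user jagged sequences; full-length random histories
+                # are materialized lazily the first time a user id is seen.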
+                item_seq = list()
+                action_seq = list()
+                for uid in user_ids:
+                    if uid not in self._item_history or uid not in self._action_history:
+                        self._item_history[uid] = torch.randint(
+                            self._max_item_id + 1,
+                            (self._max_hist_len + self._max_num_candidates,),
+                        )
+                        self._action_history[uid] = torch.randint(
+                            self._max_action_id + 1,
+                            (self._max_hist_len + self._max_num_candidates,),
+                        )
+
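+                    # Slice the cached history to the current length; the item
+                    # sequence keeps max_num_candidates extra entries to serve
+                    # as candidates.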
+                    item_seq.append(
+                        self._item_history[uid][: seqlen_idx + self._max_num_candidates]
+                    )
+                    action_seq.append(self._action_history[uid][:seqlen_idx])
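+                # Pack the per-user jagged sequences into a single KeyedJaggedTensor.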
+                features = KeyedJaggedTensor.from_jt_dict(
+                    {
+                        self._item_fea_name: JaggedTensor.from_dense(item_seq),
+                        self._action_fea_name: JaggedTensor.from_dense(action_seq),
+                    }
+                )
+
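+                # Full mode uses the maximum number of candidates for every user;
+                # otherwise each user gets a random count in [1, max_num_candidates].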
+                if self._full_mode:
+                    num_candidates = torch.full((batch_size,), self._max_num_candidates)
+                else:
+                    num_candidates = torch.randint(
+                        low=1, high=self._max_num_candidates + 1, size=(batch_size,)
+                    )
+
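+                # Histories are interleaved item/action sequences, so the total
+                # length per user is twice the current item-history length.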
+                total_history_lengths = torch.full((batch_size,), seqlen_idx * 2)
+
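+                # Assemble the batch and move it to the current CUDA device.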
+                batch = HSTUBatch(
+                    features=features,
+                    batch_size=batch_size,
+                    feature_to_max_seqlen=self._fea_name_to_max_seqlen,
+                    contextual_feature_names=self._contextual_fea_names,
+                    item_feature_name=self._item_fea_name,
+                    action_feature_name=self._action_fea_name,
+                    max_num_candidates=self._max_num_candidates,
+                    num_candidates=num_candidates,
+                ).to(device=torch.cuda.current_device())
+                self._cached_batch.append(
+                    tuple([batch, torch.tensor(user_ids).long(), total_history_lengths])
                 )
-            )
-        user_ids = torch.tensor(user_ids).long()
-        return user_ids if len(user_ids) > 0 else None
-
-    def get_random_inference_batch(
-        self, user_ids, truncate_start_positions
-    ) -> Optional[HSTUBatch]:
-        batch_size = len(user_ids)
-        if batch_size == 0:
-            return None
-        user_ids = user_ids.tolist()
-        item_hists = [
-            self._item_history[uid] if uid in self._item_history else torch.tensor([])
-            for uid in user_ids
-        ]
-        action_hists = [
-            self._action_history[uid]
-            if uid in self._action_history
-            else torch.tensor([])
-            for uid in user_ids
-        ]
-
-        lengths = torch.tensor([len(hist_seq) for hist_seq in item_hists]).long()
-        incr_lengths = torch.randint(
-            low=1, high=self._max_incr_fea_len + 1, size=(batch_size,)
-        )
-        new_lengths = torch.clamp(lengths + incr_lengths, max=self._max_hist_len).long()
-        incr_lengths = new_lengths - lengths
-
-        num_candidates = torch.randint(
-            low=1, high=self._max_num_candidates + 1, size=(batch_size,)
-        )
-        if self._full_mode:
-            incr_lengths = torch.full((batch_size,), self._max_incr_fea_len)
-            new_lengths = torch.clamp(
-                lengths + incr_lengths, max=self._max_hist_len
-            ).long()
-            incr_lengths = new_lengths - lengths
-            num_candidates = torch.full((batch_size,), self._max_num_candidates)
-
-        # Caveats: truncate_start_positions is for interleaved item-action sequence
-        item_start_positions = (truncate_start_positions / 2).to(torch.int32)
-        action_start_positions = (truncate_start_positions / 2).to(torch.int32)
-
-        item_seq = list()
-        action_seq = list()
-        for idx, uid in enumerate(user_ids):
-            self._item_history[uid] = torch.cat(
-                [
-                    item_hists[idx],
-                    torch.randint(self._max_item_id + 1, (incr_lengths[idx],)),
-                ],
-                dim=0,
-            ).long()
-            self._action_history[uid] = torch.cat(
-                [
-                    action_hists[idx],
-                    torch.randint(self._max_action_id + 1, (incr_lengths[idx],)),
-                ],
-                dim=0,
-            ).long()
-
-            item_history = torch.cat(
-                [
-                    self._item_history[uid][item_start_positions[idx] :],
-                    torch.randint(self._max_item_id + 1, (num_candidates[idx].item(),)),
-                ],
-                dim=0,
-            )
-            item_seq.append(item_history)
-            action_seq.append(self._action_history[uid][action_start_positions[idx] :])
-
-        features = KeyedJaggedTensor.from_jt_dict(
-            {
-                self._item_fea_name: JaggedTensor.from_dense(item_seq),
-                self._action_fea_name: JaggedTensor.from_dense(action_seq),
-            }
-        )
-
-        return HSTUBatch(
-            features=features,
-            batch_size=batch_size,
-            feature_to_max_seqlen=self._fea_name_to_max_seqlen,
-            contextual_feature_names=self._contextual_fea_names,
-            item_feature_name=self._item_fea_name,
-            action_feature_name=self._action_fea_name,
-            max_num_candidates=self._max_num_candidates,
-            num_candidates=num_candidates,
-        )
+                num_cached_batches += 1
+                if num_cached_batches >= self._num_generated_batches:
+                    break
+            # Stop the outer loop too; the inner break alone would generate one
+            # extra batch per remaining history length.
+            if num_cached_batches >= self._num_generated_batches:
+                break
+
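+        # Record how many batches were actually generated; the loop bounds may
+        # yield fewer than max_num_cached_batches.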
+        self._num_generated_batches = len(self._cached_batch)
+        self._max_num_batches = self._num_generated_batches
+        self._iloc = 0
+
+    def __iter__(self) -> Iterator[Tuple[HSTUBatch, torch.Tensor, torch.Tensor]]:
+        """
+        Returns an iterator over the cached batches, cycling through them.
+
+        Yields:
+            Tuple[HSTUBatch, torch.Tensor, torch.Tensor]: The next
+                (batch, user_ids, total_history_lengths) in the cycle.
+        """
+        for _ in range(len(self)):
+            yield self._cached_batch[self._iloc]
+            self._iloc = (self._iloc + 1) % self._num_generated_batches
+
+    def __len__(self) -> int:
+        """
+        Get the number of batches.
+
+        Returns:
+            int: The number of cached batches.
+        """
+        return self._max_num_batches