@@ -146,7 +146,8 @@ def load_query_samples(self, sample_list: List[int]) -> None:
146146 def unload_query_samples (self , sample_list : List [int ]) -> None :
147147 self .items_in_memory = {}
148148
149- def get_sample (self , id : int ) -> Tuple [KeyedJaggedTensor , KeyedJaggedTensor ]:
149+ def get_sample (
150+ self , id : int ) -> Tuple [KeyedJaggedTensor , KeyedJaggedTensor ]:
150151 return self .items_in_memory [self .ts ][id ]
151152
152153 def get_sample_with_ts (
@@ -192,7 +193,8 @@ def _process_line(self, line: str, user_id: int) -> pd.Series:
192193 reader = csv .reader ([line ])
193194 parsed_line = next (reader )
194195 # total ts + one more eval ts + one base ts so that uih won't be zero
195- # for each ts, ordered as candidate_ids, candidate_ratings, uih_ids, uih_ratings
196+ # for each ts, ordered as candidate_ids, candidate_ratings, uih_ids,
197+ # uih_ratings
196198 assert len (parsed_line ) == 4 * (self .total_ts + 2 )
197199 uih_item_ids_list = []
198200 uih_ratings_list = []
@@ -290,7 +292,8 @@ def set_ts(self, ts: int) -> None:
290292 assert len (row ) == 1
291293 requests = json_loads (row [0 ])
292294 self .requests = requests
293- logger .warning (f"DLRMv3SyntheticStreamingDataset: ts={ ts } requests loaded" )
295+ logger .warning (
296+ f"DLRMv3SyntheticStreamingDataset: ts={ ts } requests loaded" )
294297 assert self .ts_to_users_cumsum [self .ts ][- 1 ] == len (self .requests )
295298 logger .warning (
296299 f"DLRMv3SyntheticStreamingDataset: ts={ ts } users_cumsum={ self .ts_to_users_cumsum [self .ts ]} "
@@ -336,7 +339,8 @@ def load_item(
336339 timestamps_uih = maybe_truncate_seq (timestamps_uih , self ._max_uih_len )
337340 ids_candidates = maybe_truncate_seq (ids_candidates , max_num_candidates )
338341 num_candidates = len (ids_candidates )
339- ratings_candidates = maybe_truncate_seq (ratings_candidates , max_num_candidates )
342+ ratings_candidates = maybe_truncate_seq (
343+ ratings_candidates , max_num_candidates )
340344 action_weights_uih = [
341345 self .action_weights [int (rating ) - 1 ] for rating in ratings_uih
342346 ]
@@ -366,7 +370,8 @@ def load_item(
366370 [
367371 uih_seq_len
368372 for _ in range (
369- len (self ._uih_keys ) - len (self ._contextual_feature_to_max_length )
373+ len (self ._uih_keys ) -
374+ len (self ._contextual_feature_to_max_length )
370375 )
371376 ]
372377 )
@@ -380,7 +385,8 @@ def load_item(
380385 values = torch .tensor (uih_kjt_values ).long (),
381386 )
382387
383- candidates_kjt_lengths = num_candidates * torch .ones (len (self ._candidates_keys ))
388+ candidates_kjt_lengths = num_candidates * \
389+ torch .ones (len (self ._candidates_keys ))
384390 item_candidate_category_ids = [
385391 id // self .items_per_category for id in ids_candidates
386392 ]
0 commit comments