Skip to content

Commit bed00ee

Browse files
[Automated Commit] Format Codebase
1 parent 5bc70eb commit bed00ee

38 files changed

+664
-306
lines changed

recommendation/dlrm_v3/accuracy.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,10 @@ def main() -> None:
6767
num_candidates = data[-1].astype(int)
6868
assert len(data) == 1 + num_candidates * 3
6969
mt_target_preds = torch.from_numpy(data[0:num_candidates])
70-
mt_target_labels = torch.from_numpy(data[num_candidates : num_candidates * 2])
70+
mt_target_labels = torch.from_numpy(
71+
data[num_candidates: num_candidates * 2])
7172
mt_target_weights = torch.from_numpy(
72-
data[num_candidates * 2 : num_candidates * 3]
73+
data[num_candidates * 2: num_candidates * 3]
7374
)
7475
num_candidates = torch.tensor([num_candidates])
7576
metrics.update(

recommendation/dlrm_v3/checkpoint.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ class SparseState(Stateful):
4646
sparse_tensor_keys: Set of keys identifying sparse tensors in the model's state dict.
4747
"""
4848

49-
def __init__(self, model: torch.nn.Module, sparse_tensor_keys: Set[str]) -> None:
49+
def __init__(self, model: torch.nn.Module,
50+
sparse_tensor_keys: Set[str]) -> None:
5051
self.model = model
5152
self.sparse_tensor_keys = sparse_tensor_keys
5253

@@ -62,17 +63,23 @@ def state_dict(self) -> Dict[str, torch.Tensor]:
6263
return out_dict
6364

6465
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
65-
incompatible_keys = self.model.load_state_dict(state_dict, strict=False)
66+
incompatible_keys = self.model.load_state_dict(
67+
state_dict, strict=False)
6668
assert not incompatible_keys.unexpected_keys
6769

6870

6971
def is_sparse_key(k: str, v: torch.Tensor) -> bool:
7072
return isinstance(v, ShardedTensor) or "embedding_collection" in k
7173

7274

73-
def load_dense_state_dict(model: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
75+
def load_dense_state_dict(model: torch.nn.Module,
76+
state_dict: Dict[str, Any]) -> None:
7477
own_state = model.state_dict()
75-
own_state_dense_keys = {k for k, v in own_state.items() if not is_sparse_key(k, v)}
78+
own_state_dense_keys = {
79+
k for k,
80+
v in own_state.items() if not is_sparse_key(
81+
k,
82+
v)}
7683
state_dict_dense_keys = {
7784
k for k, v in state_dict.items() if not is_sparse_key(k, v)
7885
}
@@ -156,7 +163,8 @@ def save_dmp_checkpoint(
156163
sparse_dict = {"sparse_dict": SparseState(model, sparse_tensor_keys)}
157164
torch.distributed.checkpoint.save(
158165
sparse_dict,
159-
storage_writer=torch.distributed.checkpoint.FileSystemWriter(sparse_path),
166+
storage_writer=torch.distributed.checkpoint.FileSystemWriter(
167+
sparse_path),
160168
)
161169
torch.distributed.barrier()
162170
print("checkpoint successfully saved")
@@ -178,7 +186,8 @@ def load_sparse_checkpoint(
178186
gc.collect()
179187
torch.distributed.checkpoint.load(
180188
sparse_dict,
181-
storage_reader=torch.distributed.checkpoint.FileSystemReader(sparse_path),
189+
storage_reader=torch.distributed.checkpoint.FileSystemReader(
190+
sparse_path),
182191
)
183192
gc.collect()
184193
print("sparse checkpoint successfully loaded")

recommendation/dlrm_v3/configs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ def get_hstu_configs(dataset: str = "debug") -> DlrmHSTUConfig:
114114
return hstu_config
115115

116116

117-
def get_embedding_table_config(dataset: str = "debug") -> Dict[str, EmbeddingConfig]:
117+
def get_embedding_table_config(
118+
dataset: str = "debug") -> Dict[str, EmbeddingConfig]:
118119
"""
119120
Create and return embedding table configurations.
120121

recommendation/dlrm_v3/data_producer.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def enqueue(
9090
"""
9191
with torch.profiler.record_function("data batching"):
9292
t0_batching: float = time.time()
93-
samples: Union[Samples, List[Samples]] = self.ds.get_samples(content_ids)
93+
samples: Union[Samples, List[Samples]
94+
] = self.ds.get_samples(content_ids)
9495
dt_batching: float = time.time() - t0_batching
9596
if isinstance(samples, Samples):
9697
query = QueryItem(
@@ -106,7 +107,7 @@ def enqueue(
106107
for sample in samples:
107108
batch_size: int = sample.batch_size()
108109
query = QueryItem(
109-
query_ids=query_ids[start_idx : start_idx + batch_size],
110+
query_ids=query_ids[start_idx: start_idx + batch_size],
110111
samples=sample,
111112
start=t0,
112113
dt_queue=dt_queue,
@@ -148,7 +149,9 @@ def __init__(
148149
)
149150
self.workers: List[threading.Thread] = []
150151
for _ in range(self.threads):
151-
worker = threading.Thread(target=self.handle_tasks, args=(self.tasks,))
152+
worker = threading.Thread(
153+
target=self.handle_tasks, args=(
154+
self.tasks,))
152155
worker.daemon = True
153156
self.workers.append(worker)
154157
worker.start()
@@ -172,7 +175,8 @@ def handle_tasks(
172175
break
173176
query_ids, content_ids, t0, dt_queue = query_and_content_ids
174177
t0_batching: float = time.time()
175-
samples: Union[Samples, List[Samples]] = self.ds.get_samples(content_ids)
178+
samples: Union[Samples, List[Samples]
179+
] = self.ds.get_samples(content_ids)
176180
dt_batching: float = time.time() - t0_batching
177181
if isinstance(samples, Samples):
178182
qitem = QueryItem(
@@ -189,7 +193,7 @@ def handle_tasks(
189193
for sample in samples:
190194
batch_size: int = sample.batch_size()
191195
qitem = QueryItem(
192-
query_ids=query_ids[start_idx : start_idx + batch_size],
196+
query_ids=query_ids[start_idx: start_idx + batch_size],
193197
samples=sample,
194198
start=t0,
195199
dt_queue=dt_queue,

recommendation/dlrm_v3/datasets/dataset.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,13 @@ def kjt_batch_func(
204204
bs_offset = torch.ops.fbgemm.asynchronous_complete_cumsum(
205205
torch.tensor(bs_list)
206206
).int()
207-
batched_offset = torch.ops.fbgemm.asynchronous_complete_cumsum(batched_length)
207+
batched_offset = torch.ops.fbgemm.asynchronous_complete_cumsum(
208+
batched_length)
208209
reorder_length = torch.ops.fbgemm.reorder_batched_ad_lengths(
209210
batched_length, bs_offset, bs
210211
)
211-
reorder_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(reorder_length)
212+
reorder_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
213+
reorder_length)
212214
reorder_indices = torch.ops.fbgemm.reorder_batched_ad_indices(
213215
batched_offset, batched_indices, reorder_offsets, bs_offset, bs
214216
)
@@ -345,7 +347,8 @@ def __init__(
345347
self.num_aggregated_samples = num_aggregated_samples
346348
self.items_in_memory = {}
347349

348-
def get_sample(self, id: int) -> Tuple[KeyedJaggedTensor, KeyedJaggedTensor]:
350+
def get_sample(
351+
self, id: int) -> Tuple[KeyedJaggedTensor, KeyedJaggedTensor]:
349352
"""
350353
Get a sample by ID from in-memory storage.
351354

recommendation/dlrm_v3/datasets/synthetic_streaming.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ def load_query_samples(self, sample_list: List[int]) -> None:
146146
def unload_query_samples(self, sample_list: List[int]) -> None:
147147
self.items_in_memory = {}
148148

149-
def get_sample(self, id: int) -> Tuple[KeyedJaggedTensor, KeyedJaggedTensor]:
149+
def get_sample(
150+
self, id: int) -> Tuple[KeyedJaggedTensor, KeyedJaggedTensor]:
150151
return self.items_in_memory[self.ts][id]
151152

152153
def get_sample_with_ts(
@@ -192,7 +193,8 @@ def _process_line(self, line: str, user_id: int) -> pd.Series:
192193
reader = csv.reader([line])
193194
parsed_line = next(reader)
194195
# total ts + one more eval ts + one base ts so that uih won't be zero
195-
# for each ts, ordered as candidate_ids, candidate_ratings, uih_ids, uih_ratings
196+
# for each ts, ordered as candidate_ids, candidate_ratings, uih_ids,
197+
# uih_ratings
196198
assert len(parsed_line) == 4 * (self.total_ts + 2)
197199
uih_item_ids_list = []
198200
uih_ratings_list = []
@@ -290,7 +292,8 @@ def set_ts(self, ts: int) -> None:
290292
assert len(row) == 1
291293
requests = json_loads(row[0])
292294
self.requests = requests
293-
logger.warning(f"DLRMv3SyntheticStreamingDataset: ts={ts} requests loaded")
295+
logger.warning(
296+
f"DLRMv3SyntheticStreamingDataset: ts={ts} requests loaded")
294297
assert self.ts_to_users_cumsum[self.ts][-1] == len(self.requests)
295298
logger.warning(
296299
f"DLRMv3SyntheticStreamingDataset: ts={ts} users_cumsum={self.ts_to_users_cumsum[self.ts]}"
@@ -336,7 +339,8 @@ def load_item(
336339
timestamps_uih = maybe_truncate_seq(timestamps_uih, self._max_uih_len)
337340
ids_candidates = maybe_truncate_seq(ids_candidates, max_num_candidates)
338341
num_candidates = len(ids_candidates)
339-
ratings_candidates = maybe_truncate_seq(ratings_candidates, max_num_candidates)
342+
ratings_candidates = maybe_truncate_seq(
343+
ratings_candidates, max_num_candidates)
340344
action_weights_uih = [
341345
self.action_weights[int(rating) - 1] for rating in ratings_uih
342346
]
@@ -366,7 +370,8 @@ def load_item(
366370
[
367371
uih_seq_len
368372
for _ in range(
369-
len(self._uih_keys) - len(self._contextual_feature_to_max_length)
373+
len(self._uih_keys) -
374+
len(self._contextual_feature_to_max_length)
370375
)
371376
]
372377
)
@@ -380,7 +385,8 @@ def load_item(
380385
values=torch.tensor(uih_kjt_values).long(),
381386
)
382387

383-
candidates_kjt_lengths = num_candidates * torch.ones(len(self._candidates_keys))
388+
candidates_kjt_lengths = num_candidates * \
389+
torch.ones(len(self._candidates_keys))
384390
item_candidate_category_ids = [
385391
id // self.items_per_category for id in ids_candidates
386392
]

recommendation/dlrm_v3/datasets/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def json_loads(
4545
y = json.loads(x)
4646
else:
4747
y = x
48-
y_list = [y] if type(y) == int else list(y)
48+
y_list = [y] if isinstance(y, int) else list(y)
4949
return y_list
5050

5151

@@ -72,7 +72,7 @@ def separate_uih_candidates(
7272
y = json.loads(x)
7373
else:
7474
y = x
75-
y_list = [y] if type(y) == int else list(y)
75+
y_list = [y] if isinstance(y, int) else list(y)
7676
candidates, uih = (
7777
y_list[-candidates_max_seq_len:],
7878
y_list[:-candidates_max_seq_len],

recommendation/dlrm_v3/generative_recommenders/common.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,8 @@ def generate_sparse_seq_len(
188188
if sparsity == 0.0:
189189
return torch.zeros(size=(size,), device=device, dtype=torch.int)
190190
elif sparsity == 1.0:
191-
return torch.ones(size=(size,), device=device, dtype=torch.int) * max_seq_len
191+
return torch.ones(size=(size,), device=device,
192+
dtype=torch.int) * max_seq_len
192193
elif sparsity >= 0.5:
193194
min_seq_len: int = int((2 * sparsity - 1.0) * max_seq_len)
194195
return torch.randint(
@@ -265,10 +266,12 @@ def switch_to_contiguous_if_needed(x: torch.Tensor) -> torch.Tensor:
265266
def prev_power_of_2(x: int) -> int:
266267
if torch.compiler.is_compiling():
267268
# Re-write to make Dynamo happy
268-
x_tensor = torch.scalar_tensor(x, dtype=torch.int64) # type: ignore[arg-type]
269+
x_tensor = torch.scalar_tensor(
270+
x, dtype=torch.int64) # type: ignore[arg-type]
269271
x_tensor_orig = x_tensor.clone()
270272
out = triton.next_power_of_2(x_tensor) # type: ignore[arg-type]
271-
return int(torch.where(torch.lt(x_tensor_orig, out), out // 2, out).item()) # type: ignore[return-value]
273+
return int(torch.where(torch.lt(x_tensor_orig, out), out //
274+
2, out).item()) # type: ignore[return-value]
272275
else:
273276
out = triton.next_power_of_2(x)
274277
return out // 2 if out > x else out
@@ -340,7 +343,9 @@ def _generate_fine_grained_buckets() -> List[int]:
340343
def _fine_grained_bucket_size(x: int) -> int:
341344
if torch.compiler.is_compiling():
342345
x_tensor = torch.scalar_tensor(x, dtype=torch.int64)
343-
buckets = torch.tensor(_generate_fine_grained_buckets(), dtype=torch.int64)
346+
buckets = torch.tensor(
347+
_generate_fine_grained_buckets(),
348+
dtype=torch.int64)
344349

345350
mask = buckets >= x_tensor
346351
valid_buckets = torch.where(
@@ -361,7 +366,8 @@ def _fine_grained_bucket_size(x: int) -> int:
361366

362367

363368
@torch.fx.wrap
364-
def fx_unwrap_optional_tensor(optional: Optional[torch.Tensor]) -> torch.Tensor:
369+
def fx_unwrap_optional_tensor(
370+
optional: Optional[torch.Tensor]) -> torch.Tensor:
365371
assert optional is not None, "Expected optional to be non-None Tensor"
366372
return optional
367373

recommendation/dlrm_v3/generative_recommenders/modules/action_encoder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ def forward(
8585
watchtimes = seq_payloads[self._watchtime_feature_name]
8686
for threshold, weight in self._watchtime_to_action_thresholds_and_weights:
8787
seq_actions = torch.bitwise_or(
88-
seq_actions, (watchtimes >= threshold).to(torch.int64) * weight
88+
seq_actions, (watchtimes >= threshold).to(
89+
torch.int64) * weight
8990
)
9091
exploded_actions = (
9192
torch.bitwise_and(
@@ -94,7 +95,8 @@ def forward(
9495
> 0
9596
)
9697
action_embeddings = (
97-
exploded_actions.unsqueeze(-1) * self._action_embedding_table.unsqueeze(0)
98+
exploded_actions.unsqueeze(-1) *
99+
self._action_embedding_table.unsqueeze(0)
98100
).view(-1, self._num_action_types * self._action_embedding_dim)
99101
total_targets: int = seq_embeddings.size(0) - action_embeddings.size(0)
100102
action_embeddings = concat_2D_jagged(

recommendation/dlrm_v3/generative_recommenders/modules/content_encoder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ def forward(
7979
if self._target_enrich_dummy_embeddings:
8080
total_seq_len: int = seq_embeddings.size(0)
8181
for name, param in self._target_enrich_dummy_embeddings.items():
82-
enrich_embeddings_target = seq_payloads[name].to(seq_embeddings.dtype)
82+
enrich_embeddings_target = seq_payloads[name].to(
83+
seq_embeddings.dtype)
8384
total_targets: int = enrich_embeddings_target.size(0)
8485
total_uih_len: int = total_seq_len - total_targets
8586
enrich_embeddings_uih = param.tile(total_uih_len, 1).to(

0 commit comments

Comments
 (0)