
Commit b1ca740

Fix kvcache error and optimize aten tensors

1 parent 012728b

File tree: 8 files changed, +575 −86 lines changed
Lines changed: 392 additions & 0 deletions
@@ -0,0 +1,392 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import enum
import math
import os
import shutil
import sys
import time

import gin
import torch
from commons.utils.stringify import stringify_dict
from configs import (
    InferenceEmbeddingConfig,
    PositionEncodingConfig,
    RankingConfig,
    get_inference_hstu_config,
    get_kvcache_config,
)
from dataset import get_data_loader
from dataset.inference_dataset import InferenceDataset
from dataset.sequence_dataset import get_dataset
from modules.metrics import get_multi_event_metric_module
from preprocessor import get_common_preprocessors
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
from utils import DatasetArgs, NetworkArgs, RankingArgs

sys.path.append("./model/")
from inference_ranking_gr import InferenceRankingGR

import modules.paged_hstu_infer_layer as pg
from modules.paged_hstu_infer_layer import init

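# Purpose of this script: sanity-check that HSTU inference with the paged KV
# cache stays consistent with the plain (no-cache) forward pass. The eval
# dataset is run in two rounds (truncated histories first, then full
# histories), and per-layer tensors can optionally be dumped for offline
# comparison between the cached and no-cache runs.
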
class RunningMode(enum.Enum):
    EVAL = "eval"
    SIMULATE = "simulate"

    def __str__(self):
        return self.value


def get_inference_dataset_and_embedding_configs():
    dataset_args = DatasetArgs()
    embedding_dim = NetworkArgs().hidden_size
    HASH_SIZE = 10_000_000
    if dataset_args.dataset_name == "kuairand-1k":
        embedding_configs = [
            InferenceEmbeddingConfig(
                feature_names=["user_id"],
                table_name="user_id",
                vocab_size=1000,
                dim=embedding_dim,
                use_dynamicemb=True,
            ),
            InferenceEmbeddingConfig(
                feature_names=["user_active_degree"],
                table_name="user_active_degree",
                vocab_size=8,
                dim=embedding_dim,
                use_dynamicemb=False,
            ),
            InferenceEmbeddingConfig(
                feature_names=["follow_user_num_range"],
                table_name="follow_user_num_range",
                vocab_size=9,
                dim=embedding_dim,
                use_dynamicemb=False,
            ),
            InferenceEmbeddingConfig(
                feature_names=["fans_user_num_range"],
                table_name="fans_user_num_range",
                vocab_size=9,
                dim=embedding_dim,
                use_dynamicemb=False,
            ),
            InferenceEmbeddingConfig(
                feature_names=["friend_user_num_range"],
                table_name="friend_user_num_range",
                vocab_size=8,
                dim=embedding_dim,
                use_dynamicemb=False,
            ),
            InferenceEmbeddingConfig(
                feature_names=["register_days_range"],
                table_name="register_days_range",
                vocab_size=8,
                dim=embedding_dim,
                use_dynamicemb=False,
            ),
            InferenceEmbeddingConfig(
                feature_names=["video_id"],
                table_name="video_id",
                vocab_size=HASH_SIZE,
                dim=embedding_dim,
                use_dynamicemb=True,
            ),
            InferenceEmbeddingConfig(
                feature_names=["action_weights"],
                table_name="action_weights",
                vocab_size=233,
                dim=embedding_dim,
                use_dynamicemb=False,
            ),
        ]
        return dataset_args, embedding_configs

    raise ValueError(f"dataset {dataset_args.dataset_name} is not supported")

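# Model assembly: a bfloat16-only HSTU config, a paged KV cache sized in
# 32-token pages, and the ranking prediction heads. CUDA graph shapes are
# declared below but use_cudagraph is left False in this check.
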
def get_inference_hstu_model(
    emb_configs,
    max_batch_size,
    num_contextual_features,
    total_max_seqlen,
    checkpoint_dir,
):
    network_args = NetworkArgs()
    if network_args.dtype_str == "bfloat16":
        inference_dtype = torch.bfloat16
    # elif network_args.dtype_str == "float16":
    #     inference_dtype = torch.float16
    else:
        raise ValueError(
            f"Inference data type {network_args.dtype_str} is not supported"
        )

    position_encoding_config = PositionEncodingConfig(
        num_position_buckets=8192,
        num_time_buckets=2048,
        use_time_encoding=False,
        # Round up to a multiple of 32, matching the 32-token KV cache page size.
        static_max_seq_len=math.ceil(total_max_seqlen / 32) * 32,
    )

    hstu_config = get_inference_hstu_config(
        hidden_size=network_args.hidden_size,
        num_layers=network_args.num_layers,
        num_attention_heads=network_args.num_attention_heads,
        head_dim=network_args.kv_channels,
        dtype=inference_dtype,
        position_encoding_config=position_encoding_config,
        contextual_max_seqlen=num_contextual_features,
        scaling_seqlen=network_args.scaling_seqlen,
    )

    kvcache_args = {
        "blocks_in_primary_pool": 10240,
        "page_size": 32,
        "offload_chunksize": 1024,
        "max_batch_size": max_batch_size,
        "max_seq_len": math.ceil(total_max_seqlen / 32) * 32,
    }
    kv_cache_config = get_kvcache_config(**kvcache_args)

    ranking_args = RankingArgs()
    task_config = RankingConfig(
        embedding_configs=emb_configs,
        prediction_head_arch=ranking_args.prediction_head_arch,
        prediction_head_act_type=ranking_args.prediction_head_act_type,
        prediction_head_bias=ranking_args.prediction_head_bias,
        num_tasks=ranking_args.num_tasks,
        eval_metrics=ranking_args.eval_metrics,
    )

    hstu_cudagraph_configs = {
        "batch_size": [1],
        "length_per_sequence": [128] + [i * 256 for i in range(1, 34)],
    }

    model = InferenceRankingGR(
        hstu_config=hstu_config,
        kvcache_config=kv_cache_config,
        task_config=task_config,
        use_cudagraph=False,
        cudagraph_configs=hstu_cudagraph_configs,
    )
    if hstu_config.bf16:
        model.bfloat16()
    elif hstu_config.fp16:
        model.half()
    model.load_checkpoint(checkpoint_dir)
    model.eval()

    return model

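# The helper below shortens each jagged feature with an interleaved split:
# per-sequence (keep_len, drop_len) pairs are stacked, flattened, used to
# split the flat values, and every other chunk is kept. Toy example
# (illustrative values, not from the dataset):
#   values = [0..9], lengths = [6, 4], keep = [4, 3]
#   split_lens = [4, 2, 3, 1]
#   torch.split(values, split_lens)[::2] -> ([0, 1, 2, 3], [6, 7, 8])
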
def get_new_batch(batch, hist_lengths, ratio, num_contextuals):
    # hist_lengths counts item and action tokens together, so the truncated
    # length is halved to get a per-feature length (contextual features are
    # excluded before truncation and added back at the end).
    partial_lengths = torch.ceil(hist_lengths * ratio).long() - num_contextuals
    partial_lengths = partial_lengths // 2

    kjt_dict = batch.features.to_dict()
    item_jt = kjt_dict["video_id"]
    vals = item_jt.values()
    lens = item_jt.lengths()
    num_candidates = batch.num_candidates
    # Keep the first partial_lengths + num_candidates video tokens of every
    # sequence and drop the rest.
    split_lens = torch.stack(
        [partial_lengths + num_candidates, lens - partial_lengths - num_candidates],
        dim=1,
    ).reshape((-1,))
    stripped_vals = torch.split(vals, split_lens.tolist())[::2]
    kjt_dict["video_id"] = JaggedTensor.from_dense(stripped_vals)

    action_jt = kjt_dict["action_weights"]
    vals = action_jt.values()
    lens = action_jt.lengths()
    split_lens = torch.stack(
        [partial_lengths, lens - partial_lengths], dim=1
    ).reshape((-1,))
    stripped_vals = torch.split(vals, split_lens.tolist())[::2]
    kjt_dict["action_weights"] = JaggedTensor.from_dense(stripped_vals)

    batch.features = KeyedJaggedTensor.from_jt_dict(kjt_dict)
    hist_lengths = num_contextuals + partial_lengths * 2

    return batch, hist_lengths

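# Consistency-check driver: round 0 runs truncated histories to populate the
# KV cache, round 1 replays full histories and accumulates eval metrics. A
# cache bug therefore shows up as a metric or dumped-tensor mismatch against
# the --disable_kvcache run.
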
def run_kvcache_consistency_check(
    checkpoint_dir: str,
    disable_kvcache: bool = False,
):
    dataset_args, emb_configs = get_inference_dataset_and_embedding_configs()

    dataproc = get_common_preprocessors("")[dataset_args.dataset_name]
    num_contextual_features = len(dataproc._contextual_feature_names)

    max_batch_size = 1
    total_max_seqlen = dataset_args.max_sequence_length * 2 + num_contextual_features
    print("total_max_seqlen", total_max_seqlen)

    def strip_candidate_action_tokens(batch, action_feature_name):
        kjt_dict = batch.features.to_dict()
        action_jagged_tensor = kjt_dict[action_feature_name]
        values = action_jagged_tensor.values()
        lengths = action_jagged_tensor.lengths()
        num_candidates = batch.num_candidates
        split_lengths = torch.stack(
            [lengths - num_candidates, num_candidates], dim=1
        ).reshape((-1,))
        stripped_value = torch.split(values, split_lengths.tolist())[::2]
        kjt_dict[action_feature_name] = JaggedTensor.from_dense(stripped_value)
        batch.features = KeyedJaggedTensor.from_jt_dict(kjt_dict)
        return batch

    def strip_padding_batch(batch, unpadded_batch_size):
        batch.batch_size = unpadded_batch_size
        kjt_dict = batch.features.to_dict()
        for k in kjt_dict:
            kjt_dict[k] = JaggedTensor.from_dense_lengths(
                kjt_dict[k].to_padded_dense()[: batch.batch_size],
                kjt_dict[k].lengths()[: batch.batch_size].long(),
            )
        batch.features = KeyedJaggedTensor.from_jt_dict(kjt_dict)
        batch.num_candidates = batch.num_candidates[: batch.batch_size]
        return batch

    with torch.inference_mode():
        model = get_inference_hstu_model(
            emb_configs,
            max_batch_size,
            num_contextual_features,
            total_max_seqlen,
            checkpoint_dir,
        )

        eval_module = get_multi_event_metric_module(
            num_classes=model._task_config.prediction_head_arch[-1],
            num_tasks=model._task_config.num_tasks,
            metric_types=model._task_config.eval_metrics,
        )

        train_dataset, _ = get_dataset(
            dataset_name=dataset_args.dataset_name,
            dataset_path=dataset_args.dataset_path,
            max_sequence_length=dataset_args.max_sequence_length,
            max_num_candidates=dataset_args.max_num_candidates,
            num_tasks=model._task_config.num_tasks,
            batch_size=max_batch_size,
            rank=0,
            world_size=1,
            shuffle=False,
            random_seed=0,
            eval_batch_size=max_batch_size,
        )

        dataloader = get_data_loader(dataset=train_dataset)

        num_kvc_test_rounds = 2

        # torch.cuda.memory._record_memory_history()
        # torch.cuda.profiler.start()
        for round_id in range(num_kvc_test_rounds):
            dataloader_iter = iter(dataloader)

            length_ratio = (round_id + 1) / num_kvc_test_rounds
            while True:
                try:
                    batch = next(dataloader_iter)
                    if model._task_config.num_tasks > 0:
                        batch = strip_candidate_action_tokens(
                            batch, dataproc._action_feature_name
                        )

                    batch = batch.to(device=torch.cuda.current_device())

                    d = batch.features.to_dict()
                    user_ids = d["user_id"].values().cpu().long()
                    if user_ids.shape[0] != batch.batch_size:
                        batch = strip_padding_batch(batch, user_ids.shape[0])
                    total_history_lengths = (
                        torch.sum(
                            batch.features.lengths().view(-1, batch.batch_size), 0
                        ).view(-1)
                        - batch.num_candidates
                    )

                    # All rounds except the last replay only a prefix of each
                    # history, so the final full-length round can reuse the KV
                    # cache entries they populated.
                    if round_id != num_kvc_test_rounds - 1:
                        batch, total_history_lengths = get_new_batch(
                            batch,
                            total_history_lengths,
                            length_ratio,
                            num_contextual_features,
                        )

                    # if int(user_ids[0]) == 0:
                    #     pg.dmp = True
                    if not disable_kvcache:
                        logits = model.forward(
                            batch, user_ids, total_history_lengths.cpu()
                        )
                    else:
                        logits = model.forward_nokvcache(batch)

                    if pg.dmp:
                        # Per-layer tensors dumped by the HSTU layers go to
                        # dump/ (no KV cache) or cached/ (KV cache) for a few
                        # selected users; other users' dumps are discarded.
                        out_dir = "dump" if disable_kvcache else "cached"
                        for lidx in range(model._hstu_config.num_layers):
                            if user_ids[0] < 10 or user_ids[0] >= 690:
                                for name in ("in", "key", "value", "attn", "out"):
                                    shutil.move(
                                        f"/tmp/{name}_l{lidx}.npy",
                                        f"{out_dir}/round{round_id}_user{user_ids[0]}_{name}_l{lidx}.npy",
                                    )
                            else:
                                os.remove(f"/tmp/key_l{lidx}.npy")
                                os.remove(f"/tmp/value_l{lidx}.npy")
                        pg.dmp = False

                    # Metrics are accumulated only on the final full-length round.
                    if round_id == num_kvc_test_rounds - 1:
                        eval_module(logits, batch.labels)
                except StopIteration:
                    break
        # torch.cuda.profiler.stop()
        # torch.cuda.memory._dump_snapshot("my_snapshot.pickle")

        eval_metric_dict = eval_module.compute()
        print(
            f"[eval]:\n "
            + stringify_dict(eval_metric_dict, prefix="Metrics", sep="\n ")
        )
        # print("X")

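# Usage sketch (hypothetical script name and paths; only --gin_config_file
# and --checkpoint_dir are required):
#   python kvcache_check.py --gin_config_file kuairand.gin --checkpoint_dir ckpt/
#   python kvcache_check.py --gin_config_file kuairand.gin --checkpoint_dir ckpt/ \
#       --disable_kvcache
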
if __name__ == "__main__":
    init()
    parser = argparse.ArgumentParser(description="Inference End-to-end Example")
    parser.add_argument("--gin_config_file", type=str, required=True)
    parser.add_argument("--checkpoint_dir", type=str, required=True)
    parser.add_argument("--disable_kvcache", action="store_true")
    # parser.add_argument("--max_bs", type=int, required=True)

    args = parser.parse_args()
    gin.parse_config_file(args.gin_config_file)

    run_kvcache_consistency_check(
        checkpoint_dir=args.checkpoint_dir,
        disable_kvcache=args.disable_kvcache,
    )
    print("Finished.")
