2 changes: 2 additions & 0 deletions jenkins/L0_MergeRequest.groovy
@@ -719,6 +719,7 @@ def getMultiGpuFileChanged(pipeline, testFilter, globalVars)
"tensorrt_llm/_torch/pyexecutor/_util.py",
"tensorrt_llm/_torch/pyexecutor/model_engine.py",
"tensorrt_llm/_torch/pyexecutor/py_executor.py",
"tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py",
"tensorrt_llm/evaluate/json_mode_eval.py",
"tensorrt_llm/evaluate/mmlu.py",
"tensorrt_llm/executor/",
@@ -740,6 +741,7 @@ def getMultiGpuFileChanged(pipeline, testFilter, globalVars)
"tests/integration/defs/accuracy/test_disaggregated_serving.py",
"tests/unittest/_torch/ray_orchestrator/multi_gpu/",
"tests/integration/defs/examples/test_ray.py",
"tests/integration/defs/accuracy/test_llm_api_autodeploy.py",
"tests/unittest/llmapi/test_async_llm.py",
]

@@ -111,13 +111,12 @@ def __init__(self, config, layer_idx: int):
# S4D real initialization. These are not discretized!
# The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
A = torch.arange(1, self.num_heads + 1)
- self.A_log = nn.Parameter(torch.log(A))
- self.A_log._no_weight_decay = True
- # Instead of recomputing `torch.exp(self.A_log.float())` on every forward pass, we will register a hook
+ # self.A_log = nn.Parameter(torch.log(A))
+ # self.A_log._no_weight_decay = True
+ # Instead of recomputing `-torch.exp(self.A_log.float())` on every forward pass, we will register a hook
# that sets this appropriately when loading weights.
- # NOTE: we explicitly register this as a non-persistent buffer so that it does not appear in the state dict of
- # this module, or an equivalent graph module trace from it, but still gets included in e.g. `to()` calls.
- self.register_buffer("_minus_A", -A.float(), persistent=False)
+ self.A_minus = nn.Parameter(-A.float())
+ self.A_minus._no_weight_decay = True
self.norm = MambaRMSNormGated(
self.intermediate_size,
eps=self.layer_norm_epsilon,
@@ -129,8 +128,6 @@ def __init__(self, config, layer_idx: int):
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
self.use_bias = config.use_bias

- self.register_load_state_dict_post_hook(self._load_state_dict_post_hook)

def torch_forward(self, input_states):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
@@ -166,10 +163,9 @@ def torch_forward(self, input_states):
)

# 3. SSM transformation
- A = self._minus_A
y = torch.ops.auto_deploy.torch_ssm(
hidden_states=hidden_states.view(batch_size, seq_len, -1, self.head_dim),
- A=A,
+ A=self.A_minus,
B=B.view(batch_size, seq_len, -1, self.ssm_state_size),
C=C.view(batch_size, seq_len, -1, self.ssm_state_size),
D=self.D,
@@ -193,10 +189,6 @@ def torch_forward(self, input_states):
def forward(self, hidden_states):
return self.torch_forward(hidden_states)

- @staticmethod
- def _load_state_dict_post_hook(module, incompatible_keys) -> None:
- module._minus_A.data = -torch.exp(module.A_log.float())


class NemotronHRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
Expand Down Expand Up @@ -466,7 +458,7 @@ class NemotronHPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, NemotronHMamba2Mixer):
- module.A_log._no_weight_decay = True
+ module.A_minus._no_weight_decay = True
module.D._no_weight_decay = True

dt = torch.exp(
@@ -590,6 +582,13 @@ def __init__(self, config):
self.backbone = NemotronHModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ # Recursively iterate over all modules in self.backbone and list those with A_minus or A_log in their name
+ self.backbone_modules_with_A = []
+ for module_name, module in self.backbone.named_modules():
+ for param_name, _ in module.named_parameters(recurse=False):
+ if param_name in ("A_minus", "A_log"):
+ self.register_load_state_dict_pre_hook(self._a_log_pre_hook)
+ self.backbone_modules_with_A.append((module_name, param_name))
Review comment (Member) on lines +585 to +591:

You only need to register this load_state_dict pre-hook once, since you already iterate over all the keys below:

Suggested change:
- # Recursively iterate over all modules in self.backbone and list those with A_minus or A_log in their name
- self.backbone_modules_with_A = []
- for module_name, module in self.backbone.named_modules():
- for param_name, _ in module.named_parameters(recurse=False):
- if param_name in ("A_minus", "A_log"):
- self.register_load_state_dict_pre_hook(self._a_log_pre_hook)
- self.backbone_modules_with_A.append((module_name, param_name))
+ self.register_load_state_dict_pre_hook(self._a_log_pre_hook)


# Initialize weights and apply final processing
self.post_init()
@@ -620,5 +619,23 @@ def forward(

return NemotronHCausalLMOutput(logits)

+ @staticmethod
+ def _a_log_pre_hook(
+ module,
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ ) -> None:
+ all_keys = list(state_dict.keys())
+ for key in all_keys:
+ if "A_log" in key:
+ A_log_key = key
+ A_minus_key = key.replace("A_log", "A_minus")
+ state_dict[A_minus_key] = -torch.exp(state_dict.pop(A_log_key).float())


AutoModelForCausalLMFactory.register_custom_model_cls("NemotronHConfig", NemotronHForCausalLM)
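For context on the checkpoint-compatibility mechanism above: instead of storing `A_log` and recomputing `-torch.exp(A_log)` on every forward pass, the mixer now holds `A_minus` directly, and a `load_state_dict` pre-hook rewrites legacy `A_log` checkpoint keys on the fly. A minimal, self-contained sketch of that pattern (toy module and tensor values chosen for illustration, not the actual NemotronH classes):

```python
import torch
import torch.nn as nn


class ToyMixer(nn.Module):
    """Stores -exp(A_log) directly, but still accepts checkpoints that carry A_log."""

    def __init__(self, num_heads: int = 4):
        super().__init__()
        A = torch.arange(1, num_heads + 1).float()
        self.A_minus = nn.Parameter(-A)  # what the forward pass actually consumes
        self.register_load_state_dict_pre_hook(self._a_log_pre_hook)

    @staticmethod
    def _a_log_pre_hook(module, state_dict, prefix, local_metadata, strict,
                        missing_keys, unexpected_keys, error_msgs) -> None:
        # Rewrite legacy "A_log" entries into the "A_minus" form this module expects.
        for key in list(state_dict.keys()):
            if key.endswith("A_log"):
                state_dict[key.replace("A_log", "A_minus")] = -torch.exp(
                    state_dict.pop(key).float()
                )


# A legacy checkpoint that still carries A_log loads cleanly into the new parameterization.
legacy = {"A_log": torch.log(torch.tensor([10.0, 20.0, 30.0, 40.0]))}
m = ToyMixer()
m.load_state_dict(legacy)
assert torch.allclose(m.A_minus, torch.tensor([-10.0, -20.0, -30.0, -40.0]))
```

In the PR the hook lives on the top-level `NemotronHForCausalLM` and rewrites every matching key in the incoming state dict, which is why the reviewer's suggestion above registers it only once.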
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
@@ -13,7 +13,7 @@
from ...shim.interface import CachedSequenceInterface
from ...utils.cuda_mem_tracker import cuda_memory_tracker
from ...utils.logger import ad_logger
- from ...utils.node_utils import extract_param_names_from_node, is_linear_op, is_op
+ from ...utils.node_utils import extract_weight_name, is_linear_op, is_op
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry


@@ -36,7 +36,7 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node
y2 = y[:, out1:out1+out2]
"""
# some info we need
- keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes]
+ keys_unfused = [extract_weight_name(n) for n in linear_nodes]
params_unfused = [gm.get_parameter(k) for k in keys_unfused]
sizes_unfused = [p.size(0) for p in params_unfused]
key_fused = f"fused_weight_{idx}"
@@ -128,7 +128,7 @@ def build_custom_args_for_linear(self, scale_getattrs: Dict[str, Node]) -> Tuple
def _insert_fused_quant_gemm(
self, gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node]
):
- keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes]
+ keys_unfused = [extract_weight_name(n) for n in linear_nodes]
params_unfused = [gm.get_parameter(k) for k in keys_unfused]
sizes_unfused = [p.size(0) for p in params_unfused]
key_fused = f"fused_weight_{idx}"
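For context on what `_insert_fused_gemm` exploits (per its docstring above): sibling linear layers that consume the same input can be fused by concatenating their weights along the output dimension, running a single GEMM, and slicing the fused output back apart. A minimal equivalence check in plain PyTorch (illustrative shapes only, not the transform's actual graph rewrite):

```python
import torch
import torch.nn as nn

# Two sibling projections over the same parent input.
lin1 = nn.Linear(16, 32, bias=False)
lin2 = nn.Linear(16, 8, bias=False)
x = torch.randn(4, 16)

# Fuse: stack the weights along the output dimension and run one GEMM.
fused_weight = torch.cat([lin1.weight, lin2.weight], dim=0)  # shape (32 + 8, 16)
y = x @ fused_weight.t()

# Unfuse: slice the fused output back into per-projection results.
out1, out2 = lin1.out_features, lin2.out_features
y1, y2 = y[:, :out1], y[:, out1:out1 + out2]

assert torch.allclose(y1, lin1(x), atol=1e-5)
assert torch.allclose(y2, lin2(x), atol=1e-5)
```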
22 changes: 12 additions & 10 deletions tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
@@ -17,7 +17,7 @@
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.node_utils import (
- extract_param_names_from_node,
+ extract_weight_nodes,
get_quantization_params_from_linear_node,
is_bmm_op,
is_linear_op,
@@ -139,13 +139,13 @@ def _insert_quantized_linear(

The state_dict is also updated to contain the sharded weights.
"""
- param_name, _ = extract_param_names_from_node(node)
- original_weight = gm.get_parameter(param_name)
- new_param = nn.Parameter(self.quantize_weight(original_weight), requires_grad=False)
- modname, _, attrname = param_name.rpartition(".")
+ weight_nodes = extract_weight_nodes(node)
+ assert len(weight_nodes.weights) == 1, "Expected exactly one weight node"
+ lin_weight = weight_nodes.weights[0]
+ new_param = nn.Parameter(self.quantize_weight(lin_weight.tensor), requires_grad=False)
+ modname, _, attrname = lin_weight.node_key.rpartition(".")

- submod = gm.get_submodule(modname)
- setattr(submod, attrname, new_param)
+ setattr(lin_weight.submod, attrname, new_param)

# check modelopt quantizers from graph
if is_quantized_graph:
@@ -171,10 +171,12 @@
)
# Note: canonicalize_graph() will remove input/weight/output quantizer

- for scale_name, scale in self.default_scales(original_weight.shape).items():
- submod.register_buffer(scale_name, scale)
+ for scale_name, scale in self.default_scales(lin_weight.tensor.shape).items():
+ lin_weight.submod.register_buffer(scale_name, scale)

- gm._register_load_state_dict_pre_hook(partial(self.load_hook, weight_name=param_name))
+ gm._register_load_state_dict_pre_hook(
+ partial(self.load_hook, weight_name=lin_weight.node_key)
+ )

with gm.graph.inserting_before(node):
scales = {}
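For readers unfamiliar with the pattern `_insert_quantized_linear` implements: find the weight behind a linear node, replace it inside its owning submodule with a quantized, non-trainable parameter, and register scale buffers alongside it (plus a load hook so checkpoints are quantized at load time). A stripped-down sketch of the parameter swap on a plain module; the per-tensor int8 scheme and helper names here are illustrative assumptions, not the transform's actual `quantize_weight`/`default_scales`:

```python
import torch
import torch.nn as nn


def quantize_weight_int8(weight: torch.Tensor):
    """Toy symmetric per-tensor int8 quantization: returns (int8 weight, fp32 scale)."""
    scale = weight.abs().max() / 127.0
    q = torch.clamp(torch.round(weight / scale), -127, 127).to(torch.int8)
    return q, scale


def insert_quantized_linear(root: nn.Module, param_name: str) -> None:
    """Swap `root.<param_name>` for an int8 parameter and attach a scale buffer."""
    modname, _, attrname = param_name.rpartition(".")
    submod = root.get_submodule(modname) if modname else root
    weight = root.get_parameter(param_name)
    q, scale = quantize_weight_int8(weight.detach())
    setattr(submod, attrname, nn.Parameter(q, requires_grad=False))
    submod.register_buffer("weight_scale", scale)


model = nn.Sequential(nn.Linear(16, 32, bias=False))
insert_quantized_linear(model, "0.weight")
print(model[0].weight.dtype, model[0].weight_scale)  # torch.int8 and a scalar scale
# Note: the Linear can no longer run a plain float matmul after this swap; the real
# transform also rewrites the FX graph to call a quantized GEMM that uses these scales.
```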