Skip to content

Commit 6577783

Browse files
committed
Update based on CodeRabbit feedback
Signed-off-by: dingruiyi <dingruiyi@163.com>
1 parent 04d493d commit 6577783

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

tensorrt_llm/llmapi/rlhf_utils.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
116
import base64
217
from typing import Optional
318
import inspect
@@ -14,9 +29,8 @@ class WorkerExtension:
1429
def __init__(self):
1530
pass
1631

17-
@control_action_decorator
18-
def supports_partial_loading(self) -> bool:
19-
"""Check if the model supports partial weight loading."""
32+
def _check_partial_loading_support(self) -> bool:
33+
"""Private helper to check if the model supports partial weight loading."""
2034
try:
2135
model = self.engine.model_engine.model
2236
load_weights_args = inspect.getfullargspec(model.load_weights).args
@@ -25,6 +39,11 @@ def supports_partial_loading(self) -> bool:
2539
logger.warning(f"Failed to check partial loading support: {e}")
2640
return False
2741

42+
@control_action_decorator
43+
def supports_partial_loading(self) -> bool:
44+
"""Check if the model supports partial weight loading."""
45+
return self._check_partial_loading_support()
46+
2847
@control_action_decorator
2948
def update_weights(self, ipc_handles: Optional[dict] = None):
3049
try:
@@ -72,7 +91,7 @@ def update_weights(self, ipc_handles: Optional[dict] = None):
7291

7392
# Verify the result is a list as expected
7493
if not isinstance(all_handles, list):
75-
raise ValueError(f"Deserialized data must be a list, got {type(all_handles).__name__} instead")
94+
raise TypeError(f"Deserialized data must be a list, got {type(all_handles).__name__} instead")
7695
else:
7796
# Data is already in the correct format (backward compatibility)
7897
all_handles = serialized_handles
@@ -88,8 +107,7 @@ def update_weights(self, ipc_handles: Optional[dict] = None):
88107

89108
# Check if model supports partial loading and use appropriate strategy
90109
model = self.engine.model_engine.model
91-
load_weights_args = inspect.getfullargspec(model.load_weights).args
92-
supports_partial_loading = "allow_partial_loading" in load_weights_args
110+
supports_partial_loading = self._check_partial_loading_support()
93111

94112
if supports_partial_loading:
95113
self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=True)

0 commit comments

Comments
 (0)