1+ # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+ # SPDX-License-Identifier: Apache-2.0
3+ #
4+ # Licensed under the Apache License, Version 2.0 (the "License");
5+ # you may not use this file except in compliance with the License.
6+ # You may obtain a copy of the License at
7+ #
8+ # http://www.apache.org/licenses/LICENSE-2.0
9+ #
10+ # Unless required by applicable law or agreed to in writing, software
11+ # distributed under the License is distributed on an "AS IS" BASIS,
12+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ # See the License for the specific language governing permissions and
14+ # limitations under the License.
15+
116import base64
217from typing import Optional
318import inspect
@@ -14,9 +29,8 @@ class WorkerExtension:
1429 def __init__ (self ):
1530 pass
1631
17- @control_action_decorator
18- def supports_partial_loading (self ) -> bool :
19- """Check if the model supports partial weight loading."""
32+ def _check_partial_loading_support (self ) -> bool :
33+ """Private helper to check if the model supports partial weight loading."""
2034 try :
2135 model = self .engine .model_engine .model
2236 load_weights_args = inspect .getfullargspec (model .load_weights ).args
@@ -25,6 +39,11 @@ def supports_partial_loading(self) -> bool:
2539 logger .warning (f"Failed to check partial loading support: { e } " )
2640 return False
2741
42+ @control_action_decorator
43+ def supports_partial_loading (self ) -> bool :
44+ """Check if the model supports partial weight loading."""
45+ return self ._check_partial_loading_support ()
46+
2847 @control_action_decorator
2948 def update_weights (self , ipc_handles : Optional [dict ] = None ):
3049 try :
@@ -72,7 +91,7 @@ def update_weights(self, ipc_handles: Optional[dict] = None):
7291
7392 # Verify the result is a list as expected
7493 if not isinstance (all_handles , list ):
75- raise ValueError (f"Deserialized data must be a list, got { type (all_handles ).__name__ } instead" )
94+ raise TypeError (f"Deserialized data must be a list, got { type (all_handles ).__name__ } instead" )
7695 else :
7796 # Data is already in the correct format (backward compatibility)
7897 all_handles = serialized_handles
@@ -88,8 +107,7 @@ def update_weights(self, ipc_handles: Optional[dict] = None):
88107
89108 # Check if model supports partial loading and use appropriate strategy
90109 model = self .engine .model_engine .model
91- load_weights_args = inspect .getfullargspec (model .load_weights ).args
92- supports_partial_loading = "allow_partial_loading" in load_weights_args
110+ supports_partial_loading = self ._check_partial_loading_support ()
93111
94112 if supports_partial_loading :
95113 self .engine .model_engine .model_loader .reload (model , weights , allow_partial_loading = True )
0 commit comments