fix quantization save and load error #21504
base: master
Changes from 1 commit
@@ -1359,10 +1359,32 @@ def save_own_variables(self, store):
    Args:
        store: Dict where the state of the model will be saved.
    """
    if not getattr(self, "_is_quantized", False):
        all_vars = self._trainable_variables + self._non_trainable_variables
        for i, v in enumerate(all_vars):
            store[f"{i}"] = v
        return

    # Case: quantized layer
    quantized_vars = self._get_quantized_variables()
    for i, v in enumerate(quantized_vars):
        store[f"quantized_{i}"] = v

    # Save non-quantized variables
    all_vars = self._trainable_variables + self._non_trainable_variables
    non_quantized_vars = [
        v for v in all_vars if v not in quantized_vars and v.trainable
    ]
    for i, v in enumerate(non_quantized_vars):
        store[f"{i}"] = v
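For concreteness, here is a minimal sketch of what the quantized save path writes, using a plain dict in place of the real store and illustrative stand-in names (not from the PR). Note that in this revision only trainable float variables are written back under plain indices:

    # Illustrative stand-ins for a quantized layer's variables:
    quantized_vars = ["kernel_int8"]   # int dtype -> "quantized_{i}" keys
    non_quantized_vars = ["bias"]      # trainable float -> "{i}" keys

    store = {}
    for i, v in enumerate(quantized_vars):
        store[f"quantized_{i}"] = v
    for i, v in enumerate(non_quantized_vars):
        store[f"{i}"] = v

    print(store)  # {'quantized_0': 'kernel_int8', '0': 'bias'}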
def _get_quantized_variables(self):
    quantized_vars = []
    for v in self._trainable_variables + self._non_trainable_variables:
        if not backend.is_float_dtype(v.dtype):
            quantized_vars.append(v)
    return quantized_vars

Review comment (on the is_float_dtype check): This assumes that all integral variables come from quantization. But what if you have variables that intrinsically represent ints and are unrelated to quantization? We definitely have layers using int vars.

Reply: changed this to check for
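To make the reviewer's concern concrete, here is a hedged sketch (the StepCounter layer and its variable are hypothetical, not from this PR): a layer that keeps an integer counter for bookkeeping would have that counter misclassified as a quantized variable by a dtype-only check.

    from keras import layers

    class StepCounter(layers.Layer):
        """Hypothetical layer with an int variable unrelated to quantization."""

        def build(self, input_shape):
            # A dtype-only test such as `not backend.is_float_dtype(v.dtype)`
            # would wrongly save this counter under a "quantized_{i}" key.
            self.steps = self.add_weight(
                name="steps",
                shape=(),
                dtype="int64",
                trainable=False,
                initializer="zeros",
            )

        def call(self, inputs):
            self.steps.assign(self.steps + 1)
            return inputs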
def load_own_variables(self, store):
    """Loads the state of the layer.
@@ -1372,6 +1394,10 @@ def load_own_variables(self, store):
    Args:
        store: Dict from which the state of the model will be loaded.
    """
    if any(key.startswith("quantized_") for key in store.keys()):
        self._load_quantized_variables(store)
        return

    all_vars = self._trainable_variables + self._non_trainable_variables
    if len(store.keys()) != len(all_vars):
        if len(all_vars) == 0 and not self.built:
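The dispatch keys purely on the quantized_ prefix. A self-contained illustration of that routing, with plain dicts standing in for the real store:

    def route(store):
        # Mirrors the dispatch above: any "quantized_" key sends the
        # whole store down the quantized loading path.
        if any(key.startswith("quantized_") for key in store):
            return "quantized path"
        return "plain path"

    print(route({"0": 0, "1": 1}))            # plain path
    print(route({"quantized_0": 0, "0": 1}))  # quantized path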
@@ -1407,6 +1433,19 @@ def load_own_variables(self, store):
    for i, v in enumerate(all_vars):
        v.assign(store[f"{i}"])

def _load_quantized_variables(self, store):
    quantized_vars = self._get_quantized_variables()
    for i, v in enumerate(quantized_vars):
        v.assign(store[f"quantized_{i}"])

    # Load non-quantized variables
    all_vars = self._trainable_variables + self._non_trainable_variables
    non_quantized_vars = [
        v for v in all_vars if v not in quantized_vars and v.trainable
    ]
    for i, v in enumerate(non_quantized_vars):
        v.assign(store[f"{i}"])

def _track_variable(self, variable):
    if variable.trainable:
        self._tracker.add_to_store("trainable_variables", variable)
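End to end, the scenario this PR targets is the save/load round trip of a quantized model. A hedged usage sketch, assuming Keras 3's Model.quantize API and that these patched base-class hooks are the ones invoked:

    import keras

    model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(4)])
    model.quantize("int8")  # post-training int8 quantization
    model.save("quantized.keras")

    # Previously, reloading a quantized model tripped the variable-count
    # check in load_own_variables; with the quantized-aware hooks the
    # round trip restores both int and float variables.
    reloaded = keras.saving.load_model("quantized.keras")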
Review comment: Can you invert the if? if getattr(self, "_is_quantized", False): ... (more readable)
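A sketch of how the save hook could read with the condition inverted as suggested, using the same helpers as in the diff (same behavior, reordered branches):

    def save_own_variables(self, store):
        if getattr(self, "_is_quantized", False):
            # Quantized layer: int-dtype vars under "quantized_{i}",
            # remaining trainable float vars under plain "{i}".
            quantized_vars = self._get_quantized_variables()
            for i, v in enumerate(quantized_vars):
                store[f"quantized_{i}"] = v
            all_vars = (
                self._trainable_variables + self._non_trainable_variables
            )
            non_quantized_vars = [
                v for v in all_vars if v not in quantized_vars and v.trainable
            ]
            for i, v in enumerate(non_quantized_vars):
                store[f"{i}"] = v
            return

        # Non-quantized layer: flat enumeration, unchanged behavior.
        all_vars = self._trainable_variables + self._non_trainable_variables
        for i, v in enumerate(all_vars):
            store[f"{i}"] = v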