-
Notifications
You must be signed in to change notification settings - Fork 295
fix get_vocab_size for multimodal #984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,6 +44,9 @@ def get_model_architectures(model_path: str): | |
| def get_vocab_size(model_path: str): | ||
| try: | ||
| config_json = get_config_json(model_path) | ||
| if "llm_config" in config_json: | ||
| vocab_size = int(config_json["llm_config"]["vocab_size"]) | ||
| return vocab_size | ||
|
Comment on lines
+47
to
+49
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This new block of code duplicates the logic for extracting and casting config_json = get_config_json(model_path)
# Select the right config dictionary
if "llm_config" in config_json:
config_json = config_json["llm_config"]
# Extract vocab_size from the selected config
vocab_size = config_json["vocab_size"]
if not isinstance(vocab_size, int):
vocab_size = int(vocab_size)
return vocab_size |
||
| vocab_size = config_json["vocab_size"] | ||
| if not isinstance(vocab_size, int): | ||
| vocab_size = int(vocab_size) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
epsvalue is hardcoded as1e-10in the function call, but the function already receives anepsparameter. It's better to use the provided parameter to make the function more flexible and to respect the function's contract. Also, theFalseargument is a "magic value" which makes the code harder to understand without context. Please consider adding a comment explaining its purpose, or use a named argument if thesgl_opsAPI supports it.