Skip to content

Commit 533c4cb

Browse files
authored
fix(jax): setattr case_embd (#5104)
`case_embd` was supported but the JAX backend was not touched. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **Bug Fixes** * Improved parameter handling in the fitting module to properly support the `case_embd` parameter, ensuring it receives consistent treatment and array conversion as other fitting parameters. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 26013cb commit 533c4cb

File tree

2 files changed

+2
-0
lines changed

2 files changed

+2
-0
lines changed

deepmd/jax/fitting/fitting.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any:
4343
"fparam_inv_std",
4444
"aparam_avg",
4545
"aparam_inv_std",
46+
"case_embd",
4647
"default_fparam_tensor",
4748
}:
4849
value = to_jax_array(value)

source/tests/array_api_strict/fitting/fitting.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any:
3131
"fparam_inv_std",
3232
"aparam_avg",
3333
"aparam_inv_std",
34+
"case_embd",
3435
"default_fparam_tensor",
3536
}:
3637
value = to_array_api_strict_array(value)

0 commit comments

Comments
 (0)