|
     MissConfig: "miss_",
     TrainableTokensConfig: "trainable_tokens_",
     WaveFTConfig: "waveft_",
+    OSFConfig: "osf_",
 }
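For context: this table maps each PEFT config class to the prefix its adapter parameters carry in the model, so the new entry registers `osf_` as the parameter prefix for OSF. A hypothetical illustration of how such a mapping can be consumed in a test, assuming these config classes are importable from `peft` at the PR's revision (the `PREFIXES` name and the helper below are assumptions, not code from this PR):

```python
from peft import MissConfig, OSFConfig, TrainableTokensConfig, WaveFTConfig

# Assumed name for the mapping shown in the hunk above.
PREFIXES = {
    MissConfig: "miss_",
    TrainableTokensConfig: "trainable_tokens_",
    WaveFTConfig: "waveft_",
    OSFConfig: "osf_",
}

def adapter_param_names(model, config_cls):
    # Hypothetical helper: pick out the parameters that belong to the adapter
    # by matching the per-method prefix, e.g. "osf_" for OSF layers.
    prefix = PREFIXES[config_cls]
    return [name for name, _ in model.named_parameters() if prefix in name]
```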
|
@@ -1829,9 +1830,7 @@ def test_forward_float16(self, test_name, model_id, config_cls, config_kwargs):
         # check that none of this raises an error
         model(**X)
 
-        if model_id in ["Conv2dGroups", "Conv2dGroups2"]:
-            # this model does not support merging
-            return
+        _skip_if_merging_not_supported(model_id, config_cls)
 
         model.merge_adapter(safe_merge=False)
         model(**X)
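This helper replaces the same copied early-return in four forward tests (this hunk and the three that follow); its definition is not part of this excerpt. A plausible sketch, assuming it skips the known `Conv2dGroups` cases and, now that `config_cls` is threaded through, presumably also methods such as OSF that do not implement merging:

```python
import pytest

def _skip_if_merging_not_supported(model_id, config_cls):
    # Sketch only -- the real body is not shown in this diff. Based on the code
    # it replaces, it bails out of the merging checks for models/methods that
    # cannot merge; pytest.skip is assumed in place of the old bare `return`.
    if model_id in ["Conv2dGroups", "Conv2dGroups2"]:
        pytest.skip(f"Merging not supported for {model_id}")
    if issubclass(config_cls, OSFConfig):  # assumption: OSF cannot merge
        pytest.skip("Merging not supported for OSF")
```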
@@ -1871,9 +1870,7 @@ def test_forward_bfloat16(self, test_name, model_id, config_cls, config_kwargs):
         # check that none of this raises an error
         model(**X)
 
-        if model_id in ["Conv2dGroups", "Conv2dGroups2"]:
-            # this model does not support merging
-            return
+        _skip_if_merging_not_supported(model_id, config_cls)
 
         model.merge_adapter(safe_merge=False)
         model(**X)
@@ -1912,9 +1909,7 @@ def test_forward_float16_no_autocast(self, test_name, model_id, config_cls, config_kwargs):
         # check that none of this raises an error
         model(**X)
 
-        if model_id in ["Conv2dGroups", "Conv2dGroups2"]:
-            # this model does not support merging
-            return
+        _skip_if_merging_not_supported(model_id, config_cls)
 
         model.merge_adapter(safe_merge=False)
         model(**X)
@@ -1953,9 +1948,7 @@ def test_forward_bfloat16_no_autocast(self, test_name, model_id, config_cls, config_kwargs):
         # check that none of this raises an error
         model(**X)
 
-        if model_id in ["Conv2dGroups", "Conv2dGroups2"]:
-            # this model does not support merging
-            return
+        _skip_if_merging_not_supported(model_id, config_cls)
 
         model.merge_adapter(safe_merge=False)
         model(**X)
@@ -2032,7 +2025,7 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, config_kwargs):
             lr = 0.1  # otherwise we get nan
         elif "mha" in model_id.lower():
             lr = 1e-3  # we get exploding gradients with MHA when learning rate is too high
-        elif issubclass(config_cls, VBLoRAConfig) or issubclass(config_cls, RandLoraConfig):
+        elif issubclass(config_cls, (VBLoRAConfig, RandLoraConfig, OSFConfig)):
             lr = 0.01  # otherwise we get nan
         optimizer = torch.optim.SGD(model.parameters(), lr=lr)
 
|
@@ -2083,7 +2076,11 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
             torch.nn.init.zeros_(model.vblora_vector_bank["default"])
         model.eval()
         outputs_before = model(**X)
-        assert torch.allclose(outputs_base, outputs_before)
+        # OSF uses SVD reconstruction which introduces small numerical differences
+        if issubclass(config_cls, OSFConfig):
+            assert torch.allclose(outputs_base, outputs_before, rtol=1e-4, atol=1e-4)
+        else:
+            assert torch.allclose(outputs_base, outputs_before)
 
         if issubclass(config_cls, VBLoRAConfig):
             # initialize `vblora_vector_bank` so it can be trained
@@ -2121,7 +2118,11 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
         else:
             rtol, atol = 1e-5, 1e-8
         assert not torch.allclose(outputs_before, outputs_after, rtol=rtol, atol=atol)
-        assert torch.allclose(outputs_before, outputs_disabled)
+        # OSF uses SVD reconstruction which introduces small numerical differences
+        if issubclass(config_cls, OSFConfig):
+            assert torch.allclose(outputs_before, outputs_disabled, rtol=1e-4, atol=1e-4)
+        else:
+            assert torch.allclose(outputs_before, outputs_disabled)
         assert torch.allclose(outputs_after, outputs_enabled_after_disable)
 
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
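Why the looser tolerance in the two hunks above: OSF factorizes a weight with an SVD and reconstructs it as `U @ diag(S) @ Vh`, and that round trip is only exact up to floating-point accumulation, so even a freshly initialized OSF adapter perturbs the outputs slightly. A minimal demonstration of the effect (illustrative, not taken from the PR):

```python
import torch

torch.manual_seed(0)
W = torch.randn(64, 64)

# Round-trip a weight through its full SVD, as an SVD-based method must.
U, S, Vh = torch.linalg.svd(W)
W_rec = U @ torch.diag(S) @ Vh

# The reconstruction matches only up to float32 rounding error, which is why
# the test compares OSF outputs with rtol=1e-4/atol=1e-4 instead of the
# torch.allclose defaults (rtol=1e-5, atol=1e-8).
print(torch.allclose(W, W_rec))                        # typically False at default tolerances
print(torch.allclose(W, W_rec, rtol=1e-4, atol=1e-4))  # True
print((W - W_rec).abs().max())                         # on the order of 1e-5
```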
|