feat/train-embeddings #246
Merged
Changes from 19 commits
31 commits
4caced7 add train (k0lenk4)
c5b5b2c fixed env (k0lenk4)
e13f171 deleted kwargs and local savings, added config (k0lenk4)
d26eda0 added test for train method (k0lenk4)
ee1b4a1 add EmbedderFineTuningConfig to __init__ (k0lenk4)
3052628 correct __init__ in config, remov pytest in test file (k0lenk4)
a96c27d correct some syntax isues (k0lenk4)
bdc1161 move batch_size to EmbedderFineTuningConfig (k0lenk4)
d71de34 add __init__.py to /test/embedder (k0lenk4)
941c13a Remove whitespace from blank line (k0lenk4)
3ecdc60 Merge remote-tracking branch 'origin/dev' into feat/train-embeddings (k0lenk4)
0739413 correct errors (k0lenk4)
1e161c6 the number of epochs and train objects have been increased (k0lenk4)
e67f1bc made lint (k0lenk4)
3c38ec8 add early stopping (k0lenk4)
c743c0b remake train args (k0lenk4)
71bf957 make a list of callbacks (k0lenk4)
2963a4c inline type annotation of variable "callback" (k0lenk4)
03e4c59 change save_strategy to "epoch" (k0lenk4)
714f910 default value of fp16 changed to False (k0lenk4)
cb9b2ea pull dev (voorhs)
6aa7abc integrate embeddings fine-tuning into Embedding modules (voorhs)
b88a810 pull dev (voorhs)
2970737 Update optimizer_config.schema.json (github-actions[bot])
fcf1f31 clean up `freeze` throughout tests and tutorials (voorhs)
0b0c1fa add comprehensive tests (voorhs)
19a74f4 embedder_model -> _model (voorhs)
1d73af6 fix early stopping (voorhs)
714c8c2 fix tests (voorhs)
ebf066b clear ram bug fix (voorhs)
51a9b1a try to fix windows cleanup issue (voorhs)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | @@ -182,3 +182,5 @@ vector_db* |
| | | *.db |
| | | *.sqlite |
| | | /wandb |
| | | model_output/ |
| | | my.py |
Empty file.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | @@ -0,0 +1,52 @@ |
| | | import numpy as np |
| | | |
| | | from autointent._wrappers.embedder import Embedder |
| | | from autointent.configs._transformers import EmbedderConfig, EmbedderFineTuningConfig, HFModelConfig |
| | | from autointent.context.data_handler import DataHandler |
| | | |
| | | |
| | | def test_model_updates_after_training(dataset): |
| | |     """Test that model weights actually change after training""" |
| | |     data_handler = DataHandler(dataset) |
| | | |
| | |     hf_config = HFModelConfig(model_name="intfloat/multilingual-e5-small", batch_size=8, trust_remote_code=True) |
| | | |
| | |     embedder_config = EmbedderConfig( |
| | |         **hf_config.model_dump(), |
| | |         default_prompt="Represent this text for retrieval:", |
| | |         query_prompt="Search query:", |
| | |         passage_prompt="Document:", |
| | |         similarity_fn_name="cosine", |
| | |         use_cache=False, |
| | |         freeze=False, |
| | |     ) |
| | | |
| | |     train_config = EmbedderFineTuningConfig(epoch_num=3, batch_size=8) |
| | |     embedder = Embedder(embedder_config) |
| | |     embedder._load_model() |
| | | |
| | |     for param in embedder.embedding_model.parameters(): |
| | |         assert param.requires_grad, "All trainable parameters should have requires_grad=True" |
| | | |
| | |     original_weights = [ |
| | |         param.data.detach().cpu().numpy().copy() |
| | |         for param in embedder.embedding_model.parameters() |
| | |         if param.requires_grad |
| | |     ] |
| | |     embedder.train( |
| | |         utterances=data_handler.train_utterances(0)[:1000], |
| | |         labels=data_handler.train_labels(0)[:1000], |
| | |         config=train_config, |
| | |     ) |
| | | |
| | |     trained_weights = [ |
| | |         param.data.detach().cpu().numpy().copy() |
| | |         for param in embedder.embedding_model.parameters() |
| | |         if param.requires_grad |
| | |     ] |
| | | |
| | |     weights_changed = any( |
| | |         not np.allclose(orig, trained, atol=1e-6) |
| | |         for orig, trained in zip(original_weights, trained_weights, strict=True) |
| | |     ) |
| | |     assert weights_changed, "Model weights should change after training" |