|
9 | 9 | from sklearn.utils import all_estimators |
10 | 10 | from typing_extensions import Self |
11 | 11 |
|
12 | | -from autointent.context import Context |
13 | | -from autointent.context.embedder import Embedder |
| 12 | +from autointent.context import Context, Embedder |
14 | 13 | from autointent.context.vector_index_client import VectorIndexClient |
15 | 14 | from autointent.custom_types import BaseMetadataDict, LabelType |
16 | 15 | from autointent.modules.scoring._base import ScoringModule |
@@ -109,21 +108,17 @@ def from_context( |
109 | 108 | precomputed_embeddings = True |
110 | 109 | else: |
111 | 110 | precomputed_embeddings = context.vector_index_client.exists(embedder_name) |
112 | | - context.device = context.get_device() |
113 | | - context.embedder_batch_size = context.get_batch_size() |
114 | | - context.embedder_max_length = context.get_max_length() |
115 | | - context.db_dir = context.get_db_dir() |
116 | 111 | instance = cls( |
117 | | - model_name=embedder_name, |
118 | | - device=context.device, |
| 112 | + embedder_name=embedder_name, |
| 113 | + device=context.get_device(), |
119 | 114 | seed=context.seed, |
120 | | - batch_size=context.embedder_batch_size, |
121 | | - max_length=context.embedder_max_length, |
| 115 | + batch_size=context.get_batch_size(), |
| 116 | + max_length=context.get_max_length(), |
122 | 117 | clf_name=clf_name, |
123 | 118 | clf_args=clf_args, |
124 | 119 | ) |
125 | 120 | instance.precomputed_embeddings = precomputed_embeddings |
126 | | - instance.db_dir = str(context.db_dir) |
| 121 | + instance.db_dir = str(context.get_db_dir()) |
127 | 122 | return instance |
128 | 123 |
|
129 | 124 | def fit( |
|
0 commit comments