refactor -> select_model(functional) #468

base: main

This PR makes `select_model` functional: instead of eagerly instantiating every engine replica and returning `(engine_replicas, min_inference_t, max_inference_t)`, it now returns a list of `functools.partial` factories, each of which loads and warms up one replica when called.
```diff
@@ -3,13 +3,11 @@
 import json
 from pathlib import Path
-from typing import Union
+from typing import Union, TYPE_CHECKING

-from infinity_emb.args import (
-    EngineArgs,
-)
 from infinity_emb.log_handler import logger
-from infinity_emb.transformer.abstract import BaseCrossEncoder, BaseEmbedder
+from functools import partial
 from infinity_emb.transformer.utils import (
     AudioEmbedEngine,
     EmbedderEngine,
@@ -19,9 +17,15 @@
     RerankEngine,
 )

+if TYPE_CHECKING:
+    from infinity_emb.transformer.abstract import BaseTypeHint  # , CallableReturningBaseTypeHint
+    from infinity_emb.args import (
+        EngineArgs,
+    )


 def get_engine_type_from_config(
-    engine_args: EngineArgs,
+    engine_args: "EngineArgs",
 ) -> Union[EmbedderEngine, RerankEngine, PredictEngine, ImageEmbedEngine, AudioEmbedEngine]:
     """resolved the class of inference engine path from config.json of the repo."""
     if engine_args.engine in [InferenceEngine.debugengine]:
```
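The import changes above follow the standard `typing.TYPE_CHECKING` idiom: imports needed only for annotations (`EngineArgs`, `BaseTypeHint`) move under `if TYPE_CHECKING:`, and the annotations that mention them become quoted forward references, so the names never have to exist at runtime. A minimal self-contained sketch of the idiom (`heavy_module` and `Engine` are hypothetical names, not from this repo):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, pyright), never at
    # runtime -- avoids import cost and potential circular imports.
    from heavy_module import Engine  # hypothetical module and class


def configure(engine: "Engine") -> "Engine":
    # Quoted annotations are forward references: "Engine" is resolved by
    # the type checker, not by the interpreter.
    return engine
```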
```diff
@@ -57,55 +61,51 @@ def get_engine_type_from_config
     return EmbedderEngine.from_inference_engine(engine_args.engine)


+def _get_engine_replica(unloaded_engine, engine_args, device_map) -> "BaseTypeHint":
```
Contributor: style: function lacks type hints for `unloaded_engine` and `device_map` parameters.
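One way to satisfy that review note, as a sketch only: the `Union` mirrors the return annotation of `get_engine_type_from_config` above, while the import location of `ImageEmbedEngine`/`PredictEngine` and the `str` type for `device_map` are assumptions inferred from this diff, not taken from the repo.

```python
from typing import TYPE_CHECKING, Union

from infinity_emb.transformer.utils import (  # assumed to export all five engine enums
    AudioEmbedEngine,
    EmbedderEngine,
    ImageEmbedEngine,
    PredictEngine,
    RerankEngine,
)

if TYPE_CHECKING:
    from infinity_emb.args import EngineArgs
    from infinity_emb.transformer.abstract import BaseTypeHint


def _get_engine_replica(
    unloaded_engine: Union[
        EmbedderEngine, RerankEngine, PredictEngine, ImageEmbedEngine, AudioEmbedEngine
    ],
    engine_args: "EngineArgs",
    device_map: str,  # assumption: device_mapping yields placement strings like "cuda:0"
) -> "BaseTypeHint":
    ...
```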
```diff
+    engine_args_copy = engine_args.copy()
+    engine_args_copy._loading_strategy.device_placement = device_map
+    loaded_engine = unloaded_engine.value(engine_args=engine_args_copy)
+
+    if engine_args.model_warmup:
+        # size one, warm up warm start timings.
+        # loaded_engine.warmup(batch_size=engine_args.batch_size, n_tokens=1)
+        # size one token
+        min(loaded_engine.warmup(batch_size=1, n_tokens=1)[1] for _ in range(5))
```
michaelfeil marked this conversation as resolved.
```diff
+        emb_per_sec_short, max_inference_temp, log_msg = loaded_engine.warmup(
+            batch_size=engine_args.batch_size, n_tokens=1
+        )
+
+        logger.info(log_msg)
+        # now warm up with max_token, max batch size
+        loaded_engine.warmup(batch_size=engine_args.batch_size, n_tokens=512)
+        emb_per_sec, _, log_msg = loaded_engine.warmup(
+            batch_size=engine_args.batch_size, n_tokens=512
+        )
+        logger.info(log_msg)
+        logger.info(
+            f"model warmed up, between {emb_per_sec:.2f}-{emb_per_sec_short:.2f}"
+            f" embeddings/sec at batch_size={engine_args.batch_size}"
+        )
+    return loaded_engine
```
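The core device of the refactor is `functools.partial`: binding `_get_engine_replica` to its arguments yields a zero-argument factory, so the expensive load-and-warmup work runs only when the caller invokes it. A generic, self-contained illustration of that pattern (not repo code):

```python
from functools import partial


def build(name: str, device: str) -> str:
    # Stand-in for the expensive engine construction + warmup.
    return f"{name} loaded on {device}"


# Nothing is built yet -- partial only binds the arguments.
factories = [partial(build, "my-model", d) for d in ("cuda:0", "cuda:1")]

# Construction happens here, one call per replica.
engines = [make() for make in factories]
print(engines)  # ['my-model loaded on cuda:0', 'my-model loaded on cuda:1']
```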
```diff
 def select_model(
-    engine_args: EngineArgs,
-) -> tuple[list[Union[BaseCrossEncoder, BaseEmbedder]], float, float]:
+    engine_args: "EngineArgs",
+) -> list[partial["BaseTypeHint"]]:
     """based on engine args, fully instantiates the Engine."""
     logger.info(
         f"model=`{engine_args.model_name_or_path}` selected, "
         f"using engine=`{engine_args.engine.value}`"
         f" and device=`{engine_args.device.resolve()}`"
     )
     # engine_args.update_loading_strategy()

     unloaded_engine = get_engine_type_from_config(engine_args)

     engine_replicas = []
-    min_inference_t = 4e-3
-    max_inference_t = 4e-3

     # TODO: Can be parallelized
     for device_map in engine_args._loading_strategy.device_mapping:  # type: ignore
```
Contributor: style: the `# type: ignore` on the `device_mapping` access should be replaced with a proper type annotation.
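A sketch of what that could look like; the `LoadingStrategy` field names and types here are assumptions inferred from the `_loading_strategy.device_mapping` / `.device_placement` accesses visible in this diff, not the actual class definition:

```python
from dataclasses import dataclass, field


@dataclass
class LoadingStrategy:
    # Typed instead of Any/untyped, so iterating device_mapping needs
    # no `# type: ignore`. Element type is an assumption, e.g. "cuda:0".
    device_mapping: list[str] = field(default_factory=list)
    device_placement: str = "cpu"
```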
```diff
-        engine_args_copy = engine_args.copy()
-        engine_args_copy._loading_strategy.device_placement = device_map
-        loaded_engine = unloaded_engine.value(engine_args=engine_args_copy)
-
-        if engine_args.model_warmup:
-            # size one, warm up warm start timings.
-            # loaded_engine.warmup(batch_size=engine_args.batch_size, n_tokens=1)
-            # size one token
-            min_inference_t = min(
-                min(loaded_engine.warmup(batch_size=1, n_tokens=1)[1] for _ in range(10)),
-                min_inference_t,
-            )
-            loaded_engine.warmup(batch_size=engine_args.batch_size, n_tokens=1)
-            emb_per_sec_short, max_inference_temp, log_msg = loaded_engine.warmup(
-                batch_size=engine_args.batch_size, n_tokens=1
-            )
-            max_inference_t = max(max_inference_temp, max_inference_t)
-
-            logger.info(log_msg)
-            # now warm up with max_token, max batch size
-            loaded_engine.warmup(batch_size=engine_args.batch_size, n_tokens=512)
-            emb_per_sec, _, log_msg = loaded_engine.warmup(
-                batch_size=engine_args.batch_size, n_tokens=512
-            )
-            logger.info(log_msg)
-            logger.info(
-                f"model warmed up, between {emb_per_sec:.2f}-{emb_per_sec_short:.2f}"
-                f" embeddings/sec at batch_size={engine_args.batch_size}"
-            )
-        engine_replicas.append(loaded_engine)
+        engine_replicas.append(
+            partial(_get_engine_replica, unloaded_engine, engine_args, device_map)
+        )
     assert len(engine_replicas) > 0, "No engine replicas were loaded"

-    return engine_replicas, min_inference_t, max_inference_t
+    return engine_replicas
```
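With this change, callers receive engine factories rather than loaded engines, so replica construction can be deferred or run concurrently. The sketch below shows what the `# TODO: Can be parallelized` comment enables; the `ThreadPoolExecutor` usage is illustrative, not code from this PR, and assumes `engine_args` is in scope:

```python
from concurrent.futures import ThreadPoolExecutor

# Each element is a zero-argument callable (functools.partial) that loads
# and warms up one replica when invoked.
replica_factories = select_model(engine_args)

# Materialize all replicas concurrently instead of one by one.
with ThreadPoolExecutor(max_workers=len(replica_factories)) as pool:
    engines = list(pool.map(lambda make_engine: make_engine(), replica_factories))
```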