|
1 | 1 | import asyncio |
2 | | -import datetime |
3 | 2 | import logging |
4 | | -import os |
5 | 3 | from pathlib import Path |
6 | | -from typing import Any, Dict, Literal, Optional, Union |
7 | | - |
8 | | -import git |
9 | | -from aind_behavior_curriculum import Stage, TrainerState |
10 | | -from aind_behavior_services.rig import AindBehaviorRigModel |
11 | | -from aind_behavior_services.session import AindBehaviorSessionModel |
12 | | -from aind_behavior_services.task_logic import AindBehaviorTaskLogicModel |
13 | | -from pydantic import Field |
| 4 | + |
| 5 | +from _mocks import ( |
| 6 | + LIB_CONFIG, |
| 7 | + AindBehaviorSessionModel, |
| 8 | + DemoAindDataSchemaSessionDataMapper, |
| 9 | + RigModel, |
| 10 | + TaskLogicModel, |
| 11 | + create_fake_rig, |
| 12 | + create_fake_subjects, |
| 13 | +) |
14 | 14 | from pydantic_settings import CliApp |
15 | 15 |
|
16 | 16 | from clabe import resource_monitor |
17 | 17 | from clabe.apps import CurriculumApp, CurriculumSettings, PythonScriptApp |
18 | | -from clabe.data_mapper import DataMapper |
19 | 18 | from clabe.launcher import ( |
20 | 19 | Launcher, |
21 | 20 | LauncherCliArgs, |
|
24 | 23 |
|
25 | 24 | logger = logging.getLogger(__name__) |
26 | 25 |
|
27 | | -TASK_NAME = "RandomTask" |
28 | | -LIB_CONFIG = rf"local\AindBehavior.db\{TASK_NAME}" |
29 | | - |
30 | | - |
31 | | -### Task-specific definitions |
32 | | -class RigModel(AindBehaviorRigModel): |
33 | | - rig_name: str = Field(default="TestRig", description="Rig name") |
34 | | - version: Literal["0.0.0"] = "0.0.0" |
35 | | - |
36 | | - |
37 | | -class TaskLogicModel(AindBehaviorTaskLogicModel): |
38 | | - version: Literal["0.0.0"] = "0.0.0" |
39 | | - name: Literal[TASK_NAME] = TASK_NAME |
40 | | - |
41 | | - |
42 | | -mock_trainer_state = TrainerState[Any]( |
43 | | - curriculum=None, |
44 | | - is_on_curriculum=False, |
45 | | - stage=Stage(name="TestStage", task=TaskLogicModel(name=TASK_NAME, task_parameters={"foo": "bar"})), |
46 | | -) |
47 | | - |
48 | | - |
49 | | -class MockAindDataSchemaSession: |
50 | | - def __init__( |
51 | | - self, |
52 | | - computer_name: Optional[str] = None, |
53 | | - repository: Optional[Union[os.PathLike, git.Repo]] = None, |
54 | | - task_name: Optional[str] = None, |
55 | | - ): |
56 | | - self.computer_name = computer_name |
57 | | - self.repository = repository |
58 | | - self.task_name = task_name |
59 | | - |
60 | | - def __str__(self) -> str: |
61 | | - return f"MockAindDataSchemaSession(computer_name={self.computer_name}, repository={self.repository}, task_name={self.task_name})" |
62 | | - |
63 | | - |
64 | | -class DemoAindDataSchemaSessionDataMapper(DataMapper[MockAindDataSchemaSession]): |
65 | | - def __init__( |
66 | | - self, |
67 | | - rig_model: RigModel, |
68 | | - session_model: AindBehaviorSessionModel, |
69 | | - task_logic_model: TaskLogicModel, |
70 | | - repository: Union[os.PathLike, git.Repo], |
71 | | - script_path: os.PathLike, |
72 | | - session_end_time: Optional[datetime.datetime] = None, |
73 | | - output_parameters: Optional[Dict] = None, |
74 | | - ): |
75 | | - super().__init__() |
76 | | - self.session_model = session_model |
77 | | - self.rig_model = rig_model |
78 | | - self.task_logic_model = task_logic_model |
79 | | - self.repository = repository |
80 | | - self.script_path = script_path |
81 | | - self.session_end_time = session_end_time |
82 | | - self.output_parameters = output_parameters |
83 | | - self._mapped: Optional[MockAindDataSchemaSession] = None |
84 | | - |
85 | | - def map(self) -> MockAindDataSchemaSession: |
86 | | - self._mapped = MockAindDataSchemaSession( |
87 | | - computer_name=self.rig_model.computer_name, repository=self.repository, task_name=self.task_logic_model.name |
88 | | - ) |
89 | | - print("#" * 50) |
90 | | - print("THIS IS MAPPED DATA!") |
91 | | - print("#" * 50) |
92 | | - print(self._mapped) |
93 | | - return self._mapped |
94 | | - |
95 | 26 |
|
96 | 27 | async def experiment(launcher: Launcher) -> None: |
97 | 28 | monitor = resource_monitor.ResourceMonitor( |
@@ -145,23 +76,6 @@ def fmt(value: str) -> str: |
145 | 76 | return |
146 | 77 |
|
147 | 78 |
|
148 | | -def create_fake_subjects(): |
149 | | - subjects = ["00000", "123456"] |
150 | | - for subject in subjects: |
151 | | - os.makedirs(f"{LIB_CONFIG}/Subjects/{subject}", exist_ok=True) |
152 | | - with open(f"{LIB_CONFIG}/Subjects/{subject}/task_logic.json", "w", encoding="utf-8") as f: |
153 | | - f.write(TaskLogicModel(task_parameters={"subject": subject}).model_dump_json(indent=2)) |
154 | | - with open(f"{LIB_CONFIG}/Subjects/{subject}/trainer_state.json", "w", encoding="utf-8") as f: |
155 | | - f.write(mock_trainer_state.model_dump_json(indent=2)) |
156 | | - |
157 | | - |
158 | | -def create_fake_rig(): |
159 | | - computer_name = os.getenv("COMPUTERNAME") |
160 | | - os.makedirs(_dir := f"{LIB_CONFIG}/Rig/{computer_name}", exist_ok=True) |
161 | | - with open(f"{_dir}/rig1.json", "w", encoding="utf-8") as f: |
162 | | - f.write(RigModel().model_dump_json(indent=2)) |
163 | | - |
164 | | - |
165 | 79 | def main(): |
166 | 80 | create_fake_subjects() |
167 | 81 | create_fake_rig() |
|
0 commit comments