 from datasets import Dataset
 import httpx
 import openai
+import platformdirs

 # First Party
 # pylint: disable=ungrouped-imports
@@ -164,29 +165,48 @@ def _gen_test_data( |
     outfile.write("\n")


+def _check_pipeline_dir(pipeline):
+    for file in ["knowledge.yaml", "freeform_skills.yaml", "grounded_skills.yaml"]:
+        if not os.path.exists(os.path.join(pipeline, file)):
+            raise GenerateException(
+                f"Error: pipeline directory ({pipeline}) does not contain {file}."
+            )
+
+
 def _sdg_init(pipeline, client, model_family, model_id, num_instructions_to_generate):
     pipeline_pkg = None
-    if pipeline == "full":
-        pipeline_pkg = FULL_PIPELINES_PACKAGE
-    elif pipeline == "simple":
-        pipeline_pkg = SIMPLE_PIPELINES_PACKAGE
+
+    # Search for the pipeline in User and Site data directories
+    # then for a package defined pipeline
+    # and finally pipelines referenced by absolute path
+    pd = platformdirs.PlatformDirs(
+        appname=os.path.join("instructlab", "sdg"), multipath=True
+    )
+    for d in pd.iter_data_dirs():
+        if os.path.exists(os.path.join(d, pipeline)):
+            pipeline = os.path.join(d, pipeline)
+            _check_pipeline_dir(pipeline)
+            break
     else:
-        # Validate that pipeline is a valid directory and that it contains the required files
-        if not os.path.exists(pipeline):
-            raise GenerateException(
-                f"Error: pipeline directory ({pipeline}) does not exist."
-            )
-        for file in ["knowledge.yaml", "freeform_skills.yaml", "grounded_skills.yaml"]:
-            if not os.path.exists(os.path.join(pipeline, file)):
+        if pipeline == "full":
+            pipeline_pkg = FULL_PIPELINES_PACKAGE
+        elif pipeline == "simple":
+            pipeline_pkg = SIMPLE_PIPELINES_PACKAGE
+        else:
+            # Validate that pipeline is a valid directory and that it contains the required files
+            if not os.path.exists(pipeline):
                 raise GenerateException(
-                    f"Error: pipeline directory ({pipeline}) does not contain {file}."
+                    f"Error: pipeline directory ({pipeline}) does not exist."
                 )
+            _check_pipeline_dir(pipeline)

     ctx = PipelineContext(client, model_family, model_id, num_instructions_to_generate)

     def load_pipeline(yaml_basename):
         if pipeline_pkg:
-            with resources.path(pipeline_pkg, yaml_basename) as yaml_path:
+            with resources.as_file(
+                resources.files(pipeline_pkg).joinpath(yaml_basename)
+            ) as yaml_path:
                 return Pipeline.from_file(ctx, yaml_path)
         else:
             return Pipeline.from_file(ctx, os.path.join(pipeline, yaml_basename))
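
Aside (not part of the diff): a minimal sketch of the lookup order the new code introduces, assuming a hypothetical pipeline name "mycustom". It only uses the same platformdirs calls shown above; a directory match in a user or site data directory takes precedence over the packaged "simple"/"full" aliases.

import os
import platformdirs

# Same search configuration as the diff: user and site data dirs under instructlab/sdg
pd = platformdirs.PlatformDirs(
    appname=os.path.join("instructlab", "sdg"), multipath=True
)
for d in pd.iter_data_dirs():
    candidate = os.path.join(d, "mycustom")  # "mycustom" is a hypothetical pipeline name
    print(candidate, "(exists)" if os.path.exists(candidate) else "(not found)")
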
@@ -236,7 +256,8 @@ def generate_data( |
     use the SDG library constructs directly, and this function will likely be removed.

     Args:
-        pipeline: This argument may be either an alias defined by the sdg library ("simple", "full"),
+        pipeline: This argument may be an alias defined in a user or site "data directory",
+            an alias defined by the sdg library ("simple", "full") if the data directory has no match,
             or an absolute path to a directory containing the pipeline YAML files.
             We expect three files to be present in this directory: "knowledge.yaml",
             "freeform_skills.yaml", and "grounded_skills.yaml".
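
For illustration, a small sketch of the directory requirement the docstring describes (the path below is hypothetical): whether the pipeline is resolved from a data directory or given as an absolute path, its directory must contain the three YAML files.

import os

pipeline_dir = "/path/to/my_pipeline"  # hypothetical custom pipeline directory
required = ["knowledge.yaml", "freeform_skills.yaml", "grounded_skills.yaml"]
missing = [f for f in required if not os.path.exists(os.path.join(pipeline_dir, f))]
print("valid pipeline directory" if not missing else f"missing files: {missing}")
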