Skip to content

Commit e4765b9

Browse files
authored
Merge pull request #166 from derekhiggins/load-pipes
Load custom pipelines from shared data dir
2 parents 263372b + a78525e commit e4765b9

File tree

2 files changed

+36
-14
lines changed

2 files changed

+36
-14
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ click>=8.1.7,<9.0.0
33
httpx>=0.25.0,<1.0.0
44
langchain-text-splitters
55
openai>=1.13.3,<2.0.0
6+
platformdirs>=4.2
67
# Note: this dependency goes along with langchain-text-splitters and may be
78
# removed once that one is removed.
89
# do not use 8.4.0 due to a bug in the library

src/instructlab/sdg/generate_data.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from datasets import Dataset
1515
import httpx
1616
import openai
17+
import platformdirs
1718

1819
# First Party
1920
# pylint: disable=ungrouped-imports
@@ -164,29 +165,48 @@ def _gen_test_data(
164165
outfile.write("\n")
165166

166167

168+
def _check_pipeline_dir(pipeline):
169+
for file in ["knowledge.yaml", "freeform_skills.yaml", "grounded_skills.yaml"]:
170+
if not os.path.exists(os.path.join(pipeline, file)):
171+
raise GenerateException(
172+
f"Error: pipeline directory ({pipeline}) does not contain {file}."
173+
)
174+
175+
167176
def _sdg_init(pipeline, client, model_family, model_id, num_instructions_to_generate):
168177
pipeline_pkg = None
169-
if pipeline == "full":
170-
pipeline_pkg = FULL_PIPELINES_PACKAGE
171-
elif pipeline == "simple":
172-
pipeline_pkg = SIMPLE_PIPELINES_PACKAGE
178+
179+
# Search for the pipeline in User and Site data directories
180+
# then for a package defined pipeline
181+
# and finally pipelines referenced by absolute path
182+
pd = platformdirs.PlatformDirs(
183+
appname=os.path.join("instructlab", "sdg"), multipath=True
184+
)
185+
for d in pd.iter_data_dirs():
186+
if os.path.exists(os.path.join(d, pipeline)):
187+
pipeline = os.path.join(d, pipeline)
188+
_check_pipeline_dir(pipeline)
189+
break
173190
else:
174-
# Validate that pipeline is a valid directory and that it contains the required files
175-
if not os.path.exists(pipeline):
176-
raise GenerateException(
177-
f"Error: pipeline directory ({pipeline}) does not exist."
178-
)
179-
for file in ["knowledge.yaml", "freeform_skills.yaml", "grounded_skills.yaml"]:
180-
if not os.path.exists(os.path.join(pipeline, file)):
191+
if pipeline == "full":
192+
pipeline_pkg = FULL_PIPELINES_PACKAGE
193+
elif pipeline == "simple":
194+
pipeline_pkg = SIMPLE_PIPELINES_PACKAGE
195+
else:
196+
# Validate that pipeline is a valid directory and that it contains the required files
197+
if not os.path.exists(pipeline):
181198
raise GenerateException(
182-
f"Error: pipeline directory ({pipeline}) does not contain {file}."
199+
f"Error: pipeline directory ({pipeline}) does not exist."
183200
)
201+
_check_pipeline_dir(pipeline)
184202

185203
ctx = PipelineContext(client, model_family, model_id, num_instructions_to_generate)
186204

187205
def load_pipeline(yaml_basename):
188206
if pipeline_pkg:
189-
with resources.path(pipeline_pkg, yaml_basename) as yaml_path:
207+
with resources.as_file(
208+
resources.files(pipeline_pkg).joinpath(yaml_basename)
209+
) as yaml_path:
190210
return Pipeline.from_file(ctx, yaml_path)
191211
else:
192212
return Pipeline.from_file(ctx, os.path.join(pipeline, yaml_basename))
@@ -236,7 +256,8 @@ def generate_data(
236256
use the SDG library constructs directly, and this function will likely be removed.
237257
238258
Args:
239-
pipeline: This argument may be either an alias defined by the sdg library ("simple", "full"),
259+
pipeline: This argument may be either an alias defined in a user or site "data directory"
260+
or an alias defined by the sdg library ("simple", "full") (used only if the data directories have no match),
240261
or an absolute path to a directory containing the pipeline YAML files.
241262
We expect three files to be present in this directory: "knowledge.yaml",
242263
"freeform_skills.yaml", and "grounded_skills.yaml".

0 commit comments

Comments
 (0)