Skip to content

Commit 1998335

Browse files
committed
some bug fix in basic utterance generation
1 parent d628197 commit 1998335

File tree

4 files changed

+19
-13
lines changed

4 files changed

+19
-13
lines changed

autointent/_dataset/_dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,10 @@ def to_json(self, filepath: str | Path) -> None:
144144
145145
:param filepath: The path to the file where the JSON data will be saved.
146146
"""
147-
with Path(filepath).open("w") as file:
147+
path = Path(filepath)
148+
if not path.parent.exists():
149+
path.parent.mkdir(parents=True)
150+
with path.open("w") as file:
148151
json.dump(self.to_dict(), file, indent=4, ensure_ascii=False)
149152

150153
def push_to_hub(self, repo_id: str, private: bool = False) -> None:

autointent/generation/utterances/basic/cli.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,7 @@ def main() -> None:
2828
default=None,
2929
help="Local path where to save result",
3030
)
31-
parser.add_argument(
32-
"--private",
33-
action="store_true",
34-
help="Publish privately if --output-repo option is used"
35-
)
31+
parser.add_argument("--private", action="store_true", help="Publish privately if --output-repo option is used")
3632
parser.add_argument(
3733
"--n-generations",
3834
type=int,
@@ -72,11 +68,18 @@ def main() -> None:
7268
args = parser.parse_args()
7369

7470
dataset = load_dataset(args.input_path)
75-
generator = UtteranceGenerator(Generator(), args.custom_instruction, args.length, args.style, args.same_punctuation)
76-
generator.augment(dataset, n_generations=args.n_generations, max_sample_utterances=args.n_sample_utterances)
71+
generator = UtteranceGenerator(
72+
Generator(), args.custom_instruction or [], args.length, args.style, args.same_punctuation
73+
)
74+
generator.augment(
75+
dataset, n_generations=args.n_generations, max_sample_utterances=args.n_sample_utterances
76+
)
77+
78+
dataset.to_json(args.output_path)
7779

7880
if args.output_repo is not None:
79-
dataset.push_to_hub(args.output_repo)
81+
dataset.push_to_hub(args.output_repo, private=args.private)
82+
8083

8184
if __name__ == "__main__":
8285
main()

autointent/generation/utterances/basic/utterance_generator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,17 +89,17 @@ def augment(
8989
)
9090
if update_split:
9191
generated_split = HFDataset.from_list(new_samples)
92-
dataset[split_name] = concatenate_datasets(original_split, generated_split)
92+
dataset[split_name] = concatenate_datasets([original_split, generated_split])
9393
return [Sample(**sample) for sample in new_samples]
9494

9595

9696
def _load_prompt() -> str:
97-
with ires.files("autointent.generation.basic").joinpath("chat_template.yaml").open() as file:
97+
with ires.files("autointent.generation.utterances.basic").joinpath("chat_template.yaml").open() as file:
9898
return file.read()
9999

100100

101101
def _load_extra_instructions() -> dict[str, Any]:
102-
with ires.files("autointent.generation.basic").joinpath("extra_instructions.json").open() as file:
102+
with ires.files("autointent.generation.utterances.basic").joinpath("extra_instructions.json").open() as file:
103103
return json.load(file)
104104

105105

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ Documentation = "https://deeppavlov.github.io/AutoIntent/"
5757
"autointent" = "autointent._pipeline._cli_endpoint:optimize"
5858
"autointent-inference" = "autointent.pipeline.inference.cli_endpoint:main"
5959
"clear-cache" = "autointent.context.vector_index_client.cache:clear_chroma_cache"
60-
60+
"basic-aug" = "autointent.generation.utterances.basic.cli:main"
6161

6262
[tool.poetry.group.dev]
6363
optional = true

0 commit comments

Comments
 (0)