Skip to content

Commit eb10b16

Browse files
committed
Add CLI tools for chat fine-tuning
This implements the feature requested in issue #622 to provide CLI tools for chat fine-tuning that work with the new chat message format. Added functionality: - `openai tools chat_fine_tunes.prepare_data` for validating and preparing chat fine-tuning data in the messages format - `openai api chat_fine_tunes.*` commands (create, list, get, cancel) as convenience aliases for the existing fine_tuning.jobs commands - Chat-specific validators that check for proper message format with role and content fields - Support for system, user, and assistant roles The tools work with JSONL files containing messages arrays, which is the format required for chat fine-tuning with the new API. Fixes #622
1 parent 71dedfa commit eb10b16

File tree

7 files changed

+277
-2
lines changed

7 files changed

+277
-2
lines changed

src/openai/cli/_api/_main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from argparse import ArgumentParser
44

5-
from . import chat, audio, files, image, models, completions, fine_tuning
5+
from . import chat, audio, files, image, models, completions, fine_tuning, chat_fine_tunes
66

77

88
def register_commands(parser: ArgumentParser) -> None:
@@ -15,3 +15,4 @@ def register_commands(parser: ArgumentParser) -> None:
1515
models.register(subparsers)
1616
completions.register(subparsers)
1717
fine_tuning.register(subparsers)
18+
chat_fine_tunes.register(subparsers)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
from argparse import ArgumentParser
5+
6+
from .chat_fine_tunes import jobs
7+
8+
if TYPE_CHECKING:
9+
from argparse import _SubParsersAction
10+
11+
12+
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
13+
jobs.register(subparser)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# API commands for chat fine-tuning (convenience aliases)
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
from __future__ import annotations
2+
3+
import json
4+
from typing import TYPE_CHECKING
5+
from argparse import ArgumentParser
6+
7+
from ..._utils import get_client, print_model
8+
from ...._types import Omittable, omit
9+
from ...._utils import is_given
10+
from ..._models import BaseModel
11+
from ....pagination import SyncCursorPage
12+
from ....types.fine_tuning import (
13+
FineTuningJob,
14+
FineTuningJobEvent,
15+
)
16+
17+
if TYPE_CHECKING:
18+
from argparse import _SubParsersAction
19+
20+
21+
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
22+
sub = subparser.add_parser("chat_fine_tunes.create")
23+
sub.add_argument(
24+
"-m",
25+
"--model",
26+
help="The model to fine-tune.",
27+
required=True,
28+
)
29+
sub.add_argument(
30+
"-t",
31+
"--training-file",
32+
help="The training file to fine-tune the model on.",
33+
required=True,
34+
)
35+
sub.add_argument(
36+
"-H",
37+
"--hyperparameters",
38+
help="JSON string of hyperparameters to use for fine-tuning.",
39+
type=str,
40+
)
41+
sub.add_argument(
42+
"-s",
43+
"--suffix",
44+
help="A suffix to add to the fine-tuned model name.",
45+
)
46+
sub.add_argument(
47+
"-v",
48+
"--validation-file",
49+
help="The validation file to use for fine-tuning.",
50+
)
51+
sub.set_defaults(func=CLIChatFineTunes.create, args_model=CLIChatFineTunesCreateArgs)
52+
53+
sub = subparser.add_parser("chat_fine_tunes.get")
54+
sub.add_argument(
55+
"-i",
56+
"--id",
57+
help="The ID of the fine-tuning job to retrieve.",
58+
required=True,
59+
)
60+
sub.set_defaults(func=CLIChatFineTunes.retrieve, args_model=CLIChatFineTunesRetrieveArgs)
61+
62+
sub = subparser.add_parser("chat_fine_tunes.list")
63+
sub.add_argument(
64+
"-a",
65+
"--after",
66+
help="Identifier for the last job from the previous pagination request. If provided, only jobs created after this job will be returned.",
67+
)
68+
sub.add_argument(
69+
"-l",
70+
"--limit",
71+
help="Number of fine-tuning jobs to retrieve.",
72+
type=int,
73+
)
74+
sub.set_defaults(func=CLIChatFineTunes.list, args_model=CLIChatFineTunesListArgs)
75+
76+
sub = subparser.add_parser("chat_fine_tunes.cancel")
77+
sub.add_argument(
78+
"-i",
79+
"--id",
80+
help="The ID of the fine-tuning job to cancel.",
81+
required=True,
82+
)
83+
sub.set_defaults(func=CLIChatFineTunes.cancel, args_model=CLIChatFineTunesCancelArgs)
84+
85+
86+
class CLIChatFineTunesCreateArgs(BaseModel):
87+
model: str
88+
training_file: str
89+
hyperparameters: Omittable[str] = omit
90+
suffix: Omittable[str] = omit
91+
validation_file: Omittable[str] = omit
92+
93+
94+
class CLIChatFineTunesRetrieveArgs(BaseModel):
95+
id: str
96+
97+
98+
class CLIChatFineTunesListArgs(BaseModel):
99+
after: Omittable[str] = omit
100+
limit: Omittable[int] = omit
101+
102+
103+
class CLIChatFineTunesCancelArgs(BaseModel):
104+
id: str
105+
106+
107+
class CLIChatFineTunes:
108+
@staticmethod
109+
def create(args: CLIChatFineTunesCreateArgs) -> None:
110+
hyperparameters = json.loads(str(args.hyperparameters)) if is_given(args.hyperparameters) else omit
111+
fine_tuning_job: FineTuningJob = get_client().fine_tuning.jobs.create(
112+
model=args.model,
113+
training_file=args.training_file,
114+
hyperparameters=hyperparameters,
115+
suffix=args.suffix,
116+
validation_file=args.validation_file,
117+
)
118+
print_model(fine_tuning_job)
119+
120+
@staticmethod
121+
def retrieve(args: CLIChatFineTunesRetrieveArgs) -> None:
122+
fine_tuning_job: FineTuningJob = get_client().fine_tuning.jobs.retrieve(fine_tuning_job_id=args.id)
123+
print_model(fine_tuning_job)
124+
125+
@staticmethod
126+
def list(args: CLIChatFineTunesListArgs) -> None:
127+
fine_tuning_jobs: SyncCursorPage[FineTuningJob] = get_client().fine_tuning.jobs.list(
128+
after=args.after or omit, limit=args.limit or omit
129+
)
130+
print_model(fine_tuning_jobs)
131+
132+
@staticmethod
133+
def cancel(args: CLIChatFineTunesCancelArgs) -> None:
134+
fine_tuning_job: FineTuningJob = get_client().fine_tuning.jobs.cancel(fine_tuning_job_id=args.id)
135+
print_model(fine_tuning_job)

src/openai/cli/_tools/_main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import TYPE_CHECKING
44
from argparse import ArgumentParser
55

6-
from . import migrate, fine_tunes
6+
from . import migrate, fine_tunes, chat_fine_tunes
77

88
if TYPE_CHECKING:
99
from argparse import _SubParsersAction
@@ -15,3 +15,4 @@ def register_commands(parser: ArgumentParser, subparser: _SubParsersAction[Argum
1515
namespaced = parser.add_subparsers(title="Tools", help="Convenience client side tools")
1616

1717
fine_tunes.register(namespaced)
18+
chat_fine_tunes.register(namespaced)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
from typing import TYPE_CHECKING
5+
from argparse import ArgumentParser
6+
7+
from .._models import BaseModel
8+
from ...lib._validators import (
9+
get_chat_validators,
10+
write_out_file,
11+
read_any_format,
12+
apply_validators,
13+
apply_necessary_remediation,
14+
)
15+
16+
if TYPE_CHECKING:
17+
from argparse import _SubParsersAction
18+
19+
20+
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
21+
sub = subparser.add_parser("chat_fine_tunes.prepare_data")
22+
sub.add_argument(
23+
"-f",
24+
"--file",
25+
required=True,
26+
help="JSONL, JSON file containing chat messages with roles (system, user, assistant) to be analyzed."
27+
"This should be the local file path.",
28+
)
29+
sub.add_argument(
30+
"-q",
31+
"--quiet",
32+
required=False,
33+
action="store_true",
34+
help="Auto accepts all suggestions, without asking for user input. To be used within scripts.",
35+
)
36+
sub.set_defaults(func=prepare_data, args_model=PrepareDataArgs)
37+
38+
39+
class PrepareDataArgs(BaseModel):
40+
file: str
41+
42+
quiet: bool
43+
44+
45+
def prepare_data(args: PrepareDataArgs) -> None:
46+
sys.stdout.write("Analyzing chat fine-tuning data...\n")
47+
fname = args.file
48+
auto_accept = args.quiet
49+
df, remediation = read_any_format(fname)
50+
apply_necessary_remediation(None, remediation)
51+
52+
validators = get_chat_validators()
53+
54+
assert df is not None
55+
56+
apply_validators(
57+
df,
58+
fname,
59+
remediation,
60+
validators,
61+
auto_accept,
62+
write_out_file_func=write_out_file,
63+
)

src/openai/lib/_validators.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,56 @@ def get_common_xfix(series: Any, xfix: str = "suffix") -> str:
747747
Validator: TypeAlias = "Callable[[pd.DataFrame], Remediation | None]"
748748

749749

750+
def chat_messages_validator(df: pd.DataFrame) -> Remediation:
751+
"""
752+
This validator will ensure that the messages column contains properly formatted chat messages.
753+
"""
754+
import json
755+
756+
immediate_msg = None
757+
error_msg = None
758+
759+
# Check if we have a messages column
760+
if 'messages' not in df.columns:
761+
error_msg = "`messages` column/key is missing. Chat fine-tuning requires a 'messages' column with an array of message objects."
762+
else:
763+
# Try to validate the format
764+
try:
765+
for idx, row in df.iterrows():
766+
if 'messages' in row:
767+
messages = json.loads(row['messages']) if isinstance(row['messages'], str) else row['messages']
768+
if not isinstance(messages, list):
769+
raise ValueError(f"Messages must be a list in row {idx}")
770+
for msg in messages:
771+
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
772+
raise ValueError(f"Each message must have 'role' and 'content' in row {idx}")
773+
if msg['role'] not in ['system', 'user', 'assistant']:
774+
raise ValueError(f"Role must be 'system', 'user', or 'assistant' in row {idx}")
775+
immediate_msg = f"\n- Your file contains {len(df)} chat conversations with properly formatted messages"
776+
except (json.JSONDecodeError, ValueError, TypeError) as e:
777+
error_msg = f"Invalid messages format: {str(e)}. Messages must be a JSON array of objects with 'role' and 'content' fields."
778+
779+
return Remediation(
780+
name="chat_messages",
781+
immediate_msg=immediate_msg,
782+
error_msg=error_msg,
783+
)
784+
785+
786+
def chat_num_examples_validator(df: pd.DataFrame) -> Remediation:
787+
"""
788+
This validator will print out the number of chat examples and recommend increasing if less than 10.
789+
"""
790+
MIN_EXAMPLES = 10
791+
optional_suggestion = (
792+
""
793+
if len(df) >= MIN_EXAMPLES
794+
else ". For chat fine-tuning, we recommend having at least 10 examples, but preferably 50-100 for better results"
795+
)
796+
immediate_msg = f"\n- Your file contains {len(df)} chat conversations{optional_suggestion}"
797+
return Remediation(name="chat_num_examples", immediate_msg=immediate_msg)
798+
799+
750800
def get_validators() -> list[Validator]:
751801
return [
752802
num_examples_validator,
@@ -767,6 +817,17 @@ def get_validators() -> list[Validator]:
767817
]
768818

769819

820+
def get_chat_validators() -> list[Validator]:
821+
"""
822+
Get validators specifically for chat fine-tuning data format.
823+
"""
824+
return [
825+
chat_num_examples_validator,
826+
chat_messages_validator,
827+
duplicated_rows_validator,
828+
]
829+
830+
770831
def apply_validators(
771832
df: pd.DataFrame,
772833
fname: str,

0 commit comments

Comments
 (0)