# NOTE: this entry point must live in `cli/_api/chat_fine_tunes/__init__.py`.
# A sibling module `cli/_api/chat_fine_tunes.py` with the same name as the
# package directory is shadowed by the package at import time, so a `register`
# defined there is never reachable through `from . import chat_fine_tunes`.
from . import jobs


def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
    """Register the `chat_fine_tunes.*` API sub-commands on `subparser`.

    Delegates to `jobs.register`, which adds the `chat_fine_tunes.create`,
    `chat_fine_tunes.get`, `chat_fine_tunes.list` and `chat_fine_tunes.cancel`
    parsers.
    """
    jobs.register(subparser)
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
    """Attach the `chat_fine_tunes.{create,get,list,cancel}` sub-commands to `subparser`.

    Every sub-command records its handler as `func` and its argument model as
    `args_model` through `set_defaults`.
    """
    # chat_fine_tunes.create
    create_cmd = subparser.add_parser("chat_fine_tunes.create")
    create_cmd.add_argument("-m", "--model", help="The model to fine-tune.", required=True)
    create_cmd.add_argument(
        "-t",
        "--training-file",
        help="The training file to fine-tune the model on.",
        required=True,
    )
    create_cmd.add_argument(
        "-H",
        "--hyperparameters",
        help="JSON string of hyperparameters to use for fine-tuning.",
        type=str,
    )
    create_cmd.add_argument("-s", "--suffix", help="A suffix to add to the fine-tuned model name.")
    create_cmd.add_argument("-v", "--validation-file", help="The validation file to use for fine-tuning.")
    create_cmd.set_defaults(func=CLIChatFineTunes.create, args_model=CLIChatFineTunesCreateArgs)

    # chat_fine_tunes.get
    get_cmd = subparser.add_parser("chat_fine_tunes.get")
    get_cmd.add_argument("-i", "--id", help="The ID of the fine-tuning job to retrieve.", required=True)
    get_cmd.set_defaults(func=CLIChatFineTunes.retrieve, args_model=CLIChatFineTunesRetrieveArgs)

    # chat_fine_tunes.list
    list_cmd = subparser.add_parser("chat_fine_tunes.list")
    list_cmd.add_argument(
        "-a",
        "--after",
        help="Identifier for the last job from the previous pagination request. If provided, only jobs created after this job will be returned.",
    )
    list_cmd.add_argument("-l", "--limit", help="Number of fine-tuning jobs to retrieve.", type=int)
    list_cmd.set_defaults(func=CLIChatFineTunes.list, args_model=CLIChatFineTunesListArgs)

    # chat_fine_tunes.cancel
    cancel_cmd = subparser.add_parser("chat_fine_tunes.cancel")
    cancel_cmd.add_argument("-i", "--id", help="The ID of the fine-tuning job to cancel.", required=True)
    cancel_cmd.set_defaults(func=CLIChatFineTunes.cancel, args_model=CLIChatFineTunesCancelArgs)


# Arguments of `chat_fine_tunes.create`; optional flags default to `omit`.
class CLIChatFineTunesCreateArgs(BaseModel):
    model: str
    training_file: str
    hyperparameters: Omittable[str] = omit
    suffix: Omittable[str] = omit
    validation_file: Omittable[str] = omit


# Arguments of `chat_fine_tunes.get`.
class CLIChatFineTunesRetrieveArgs(BaseModel):
    id: str


# Arguments of `chat_fine_tunes.list`; both paging flags default to `omit`.
class CLIChatFineTunesListArgs(BaseModel):
    after: Omittable[str] = omit
    limit: Omittable[int] = omit


# Arguments of `chat_fine_tunes.cancel`.
class CLIChatFineTunesCancelArgs(BaseModel):
    id: str


class CLIChatFineTunes:
    """Handlers for the `chat_fine_tunes.*` sub-commands.

    Each handler forwards to `get_client().fine_tuning.jobs` and prints the
    result with `print_model`.
    """

    @staticmethod
    def create(args: CLIChatFineTunesCreateArgs) -> None:
        """Create a fine-tuning job; `--hyperparameters` is decoded from JSON when given."""
        if is_given(args.hyperparameters):
            hyperparameters = json.loads(str(args.hyperparameters))
        else:
            hyperparameters = omit

        job: FineTuningJob = get_client().fine_tuning.jobs.create(
            model=args.model,
            training_file=args.training_file,
            hyperparameters=hyperparameters,
            suffix=args.suffix,
            validation_file=args.validation_file,
        )
        print_model(job)

    @staticmethod
    def retrieve(args: CLIChatFineTunesRetrieveArgs) -> None:
        """Fetch the fine-tuning job identified by `args.id` and print it."""
        job: FineTuningJob = get_client().fine_tuning.jobs.retrieve(fine_tuning_job_id=args.id)
        print_model(job)

    @staticmethod
    def list(args: CLIChatFineTunesListArgs) -> None:
        """Print one page of fine-tuning jobs; falsy `after`/`limit` values are sent as `omit`."""
        after = args.after or omit
        limit = args.limit or omit
        page: SyncCursorPage[FineTuningJob] = get_client().fine_tuning.jobs.list(after=after, limit=limit)
        print_model(page)

    @staticmethod
    def cancel(args: CLIChatFineTunesCancelArgs) -> None:
        """Cancel the fine-tuning job identified by `args.id` and print the returned job."""
        job: FineTuningJob = get_client().fine_tuning.jobs.cancel(fine_tuning_job_id=args.id)
        print_model(job)
def chat_messages_validator(df: pd.DataFrame) -> Remediation:
    """
    This validator will ensure that the messages column contains properly formatted chat messages.

    Each `messages` cell must be a non-empty list (or a JSON string that decodes to one)
    of objects carrying both a `role` (system, user or assistant) and a `content` key.
    Scanning stops at the first malformed row, which is reported through `error_msg`;
    when every row passes, `immediate_msg` summarises the file instead.
    """
    import json

    allowed_roles = ("system", "user", "assistant")
    immediate_msg: str | None = None
    error_msg: str | None = None

    # Check if we have a messages column
    if "messages" not in df.columns:
        error_msg = "`messages` column/key is missing. Chat fine-tuning requires a 'messages' column with an array of message objects."
    else:
        # Try to validate the format
        try:
            # Only the `messages` column is inspected, so iterate it directly
            # instead of materialising a Series per row with `iterrows()`.
            for idx, raw in df["messages"].items():
                messages = json.loads(raw) if isinstance(raw, str) else raw
                if not isinstance(messages, list):
                    raise ValueError(f"Messages must be a list in row {idx}")
                if not messages:
                    # A conversation without any message is not a usable example.
                    raise ValueError(f"Messages must not be empty in row {idx}")
                for msg in messages:
                    if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
                        raise ValueError(f"Each message must have 'role' and 'content' in row {idx}")
                    if msg["role"] not in allowed_roles:
                        raise ValueError(f"Role must be 'system', 'user', or 'assistant' in row {idx}")
        except (ValueError, TypeError) as e:
            # json.JSONDecodeError is a ValueError subclass, so it is handled here too.
            error_msg = f"Invalid messages format: {e}. Messages must be a JSON array of objects with 'role' and 'content' fields."
        else:
            immediate_msg = f"\n- Your file contains {len(df)} chat conversations with properly formatted messages"

    return Remediation(
        name="chat_messages",
        immediate_msg=immediate_msg,
        error_msg=error_msg,
    )


def chat_num_examples_validator(df: pd.DataFrame) -> Remediation:
    """
    This validator will print out the number of chat examples and recommend increasing if less than 10.
    """
    MIN_EXAMPLES = 10
    # Derive the advice from the constant so the threshold and the text cannot drift apart.
    optional_suggestion = (
        ""
        if len(df) >= MIN_EXAMPLES
        else f". For chat fine-tuning, we recommend having at least {MIN_EXAMPLES} examples, but preferably 50-100 for better results"
    )
    immediate_msg = f"\n- Your file contains {len(df)} chat conversations{optional_suggestion}"
    return Remediation(name="chat_num_examples", immediate_msg=immediate_msg)


def get_chat_validators() -> list[Validator]:
    """
    Get validators specifically for chat fine-tuning data format.

    The validators are returned in the order they are meant to run: example count,
    message structure, then duplicate detection.
    """
    # NOTE(review): `duplicated_rows_validator` is defined outside this view. Confirm it
    # works on a frame whose only column is `messages` (list-valued cells are unhashable
    # for `DataFrame.duplicated`, and the validator may expect prompt/completion columns).
    return [
        chat_num_examples_validator,
        chat_messages_validator,
        duplicated_rows_validator,
    ]