diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 000000000..81773666c --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,4 @@ +coverage: + ignore: + - "examples/**" + - "internal/test/**" diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..e36c38239 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,3 @@ +# These are supported funding model platforms + +github: [sashabaranov, vvatanabe] diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 000000000..536a2ee29 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,32 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: bug +assignees: '' + +--- + +Your issue may already be reported! +Please search on the [issue tracker](https://github.com/sashabaranov/go-openai/issues) before creating one. + +**Describe the bug** +A clear and concise description of what the bug is. If it's an API-related bug, please provide relevant endpoint(s). + +**To Reproduce** +Steps to reproduce the behavior, including any relevant code snippets. + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Screenshots/Logs** +If applicable, add screenshots to help explain your problem. For non-graphical issues, please provide any relevant logs or stack traces. + +**Environment (please complete the following information):** + - go-openai version: [e.g. v1.12.0] + - Go version: [e.g. 1.18] + - OpenAI API version: [e.g. v1] + - OS: [e.g. Ubuntu 20.04] + +**Additional context** +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 000000000..2359e5c00 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,23 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: enhancement +assignees: '' + +--- + +Your issue may already be reported! +Please search on the [issue tracker](https://github.com/sashabaranov/go-openai/issues) before creating one. + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 000000000..222c065ce --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,23 @@ +A similar PR may already be submitted! +Please search among the [Pull request](https://github.com/sashabaranov/go-openai/pulls) before creating one. + +If your changes introduce breaking changes, please prefix the title of your pull request with "[BREAKING_CHANGES]". This allows for clear identification of such changes in the 'What's Changed' section on the release page, making it developer-friendly. + +Thanks for submitting a pull request! Please provide enough information so that others can review your pull request. + +**Describe the change** +Please provide a clear and concise description of the changes you're proposing. Explain what problem it solves or what feature it adds. + +**Provide OpenAI documentation link** +Provide a relevant API doc from https://platform.openai.com/docs/api-reference + +**Describe your solution** +Describe how your changes address the problem or how they add the feature. This should include a brief description of your approach and any new libraries or dependencies you're using. + +**Tests** +Briefly describe how you have tested these changes. If possible — please add integration tests. + +**Additional context** +Add any other context or screenshots or logs about your pull request here. If the pull request relates to an open issue, please link to it. + +Issue: #XXXX diff --git a/.github/workflows/close-inactive-issues.yml b/.github/workflows/close-inactive-issues.yml new file mode 100644 index 000000000..32723c4e9 --- /dev/null +++ b/.github/workflows/close-inactive-issues.yml @@ -0,0 +1,23 @@ +name: Close inactive issues +on: + schedule: + - cron: "30 1 * * *" + +jobs: + close-issues: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - uses: actions/stale@v9 + with: + days-before-issue-stale: 30 + days-before-issue-close: 14 + stale-issue-label: "stale" + exempt-issue-labels: 'bug,enhancement' + stale-issue-message: "This issue is stale because it has been open for 30 days with no activity." + close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." + days-before-pr-stale: -1 + days-before-pr-close: -1 + repo-token: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml new file mode 100644 index 000000000..7260b00b4 --- /dev/null +++ b/.github/workflows/integration-tests.yml @@ -0,0 +1,21 @@ +name: Integration tests + +on: + push: + branches: + - master + +jobs: + integration_tests: + name: Run integration tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.21' + - name: Run integration tests + env: + OPENAI_TOKEN: ${{ secrets.OPENAI_TOKEN }} + run: go test -v -tags=integration ./api_integration_test.go diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 8df721f0f..2c9730656 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -9,19 +9,21 @@ jobs: name: Sanity check runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Setup Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: - go-version: '1.19' + go-version: '1.24' - name: Run vet run: | - go vet . + go vet -stdversion ./... - name: Run golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v7 with: - version: latest + version: v2.1.5 - name: Run tests - run: go test -race -covermode=atomic -coverprofile=coverage.out -v . + run: go test -race -covermode=atomic -coverprofile=coverage.out -v ./... - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.gitignore b/.gitignore index 99b40bf17..b0ac1605c 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,7 @@ # Auth token for tests .openai-token -.idea \ No newline at end of file +.idea + +# Generated by tests +test.mp3 \ No newline at end of file diff --git a/.golangci.yml b/.golangci.yml index 58fab4a20..6391ad76f 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,272 +1,168 @@ -## Golden config for golangci-lint v1.47.3 -# -# This is the best config for golangci-lint based on my experience and opinion. -# It is very strict, but not extremely strict. -# Feel free to adopt and change it for your needs. - -run: - # Timeout for analysis, e.g. 30s, 5m. - # Default: 1m - timeout: 3m - - -# This file contains only configs which differ from defaults. -# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml -linters-settings: - cyclop: - # The maximal code complexity to report. - # Default: 10 - max-complexity: 30 - # The maximal average package complexity. - # If it's higher than 0.0 (float) the check is enabled - # Default: 0.0 - package-average: 10.0 - - errcheck: - # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. - # Such cases aren't reported by default. - # Default: false - check-type-assertions: true - - funlen: - # Checks the number of lines in a function. - # If lower than 0, disable the check. - # Default: 60 - lines: 100 - # Checks the number of statements in a function. - # If lower than 0, disable the check. - # Default: 40 - statements: 50 - - gocognit: - # Minimal code complexity to report - # Default: 30 (but we recommend 10-20) - min-complexity: 20 - - gocritic: - # Settings passed to gocritic. - # The settings key is the name of a supported gocritic checker. - # The list of supported checkers can be find in https://go-critic.github.io/overview. - settings: - captLocal: - # Whether to restrict checker to params only. - # Default: true - paramsOnly: false - underef: - # Whether to skip (*x).method() calls where x is a pointer receiver. - # Default: true - skipRecvDeref: false - - gomnd: - # List of function patterns to exclude from analysis. - # Values always ignored: `time.Date` - # Default: [] - ignored-functions: - - os.Chmod - - os.Mkdir - - os.MkdirAll - - os.OpenFile - - os.WriteFile - - prometheus.ExponentialBuckets - - prometheus.ExponentialBucketsRange - - prometheus.LinearBuckets - - strconv.FormatFloat - - strconv.FormatInt - - strconv.FormatUint - - strconv.ParseFloat - - strconv.ParseInt - - strconv.ParseUint - - gomodguard: - blocked: - # List of blocked modules. - # Default: [] - modules: - - github.com/golang/protobuf: - recommendations: - - google.golang.org/protobuf - reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules" - - github.com/satori/go.uuid: - recommendations: - - github.com/google/uuid - reason: "satori's package is not maintained" - - github.com/gofrs/uuid: - recommendations: - - github.com/google/uuid - reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw" - - govet: - # Enable all analyzers. - # Default: false - enable-all: true - # Disable analyzers by name. - # Run `go tool vet help` to see all analyzers. - # Default: [] - disable: - - fieldalignment # too strict - # Settings per analyzer. - settings: - shadow: - # Whether to be strict about shadowing; can be noisy. - # Default: false - strict: true - - nakedret: - # Make an issue if func has more lines of code than this setting, and it has naked returns. - # Default: 30 - max-func-lines: 0 - - nolintlint: - # Exclude following linters from requiring an explanation. - # Default: [] - allow-no-explanation: [ funlen, gocognit, lll ] - # Enable to require an explanation of nonzero length after each nolint directive. - # Default: false - require-explanation: true - # Enable to require nolint directives to mention the specific linter being suppressed. - # Default: false - require-specific: true - - rowserrcheck: - # database/sql is always checked - # Default: [] - packages: - - github.com/jmoiron/sqlx - - tenv: - # The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures. - # Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked. - # Default: false - all: true - - varcheck: - # Check usage of exported fields and variables. - # Default: false - exported-fields: false # default false # TODO: enable after fixing false positives - - +version: "2" linters: - disable-all: true + default: none enable: - ## enabled by default - - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - - gosimple # Linter for Go source code that specializes in simplifying a code - - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - - ineffassign # Detects when assignments to existing variables are not used - - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - - unused # Checks Go code for unused constants, variables, functions and types - ## disabled by default - # - asasalint # Check for pass []any as any in variadic func(...any) - - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - - bidichk # Checks for dangerous unicode character sequences - - bodyclose # checks whether HTTP response body is closed successfully - - contextcheck # check the function whether use a non-inherited context - - cyclop # checks function and package cyclomatic complexity - - dupl # Tool for code clone detection - - durationcheck # check for two durations multiplied together - - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error. - - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - - execinquery # execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds - - exhaustive # check exhaustiveness of enum switch statements - - exportloopref # checks for pointers to enclosing loop variables - - forbidigo # Forbids identifiers - - funlen # Tool for detection of long functions - # - gochecknoglobals # check that no global variables exist - - gochecknoinits # Checks that no init functions are present in Go code - - gocognit # Computes and checks the cognitive complexity of functions - - goconst # Finds repeated strings that could be replaced by a constant - - gocritic # Provides diagnostics that check for bugs, performance and style issues. - - gocyclo # Computes and checks the cyclomatic complexity of functions - - godot # Check if comments end in a period - - goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt. - - gomnd # An analyzer to detect magic numbers. - - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - - goprintffuncname # Checks that printf-like functions are named with f at the end - - gosec # Inspects source code for security problems - - lll # Reports long lines - - makezero # Finds slice declarations with non-zero initial length - # - nakedret # Finds naked returns in functions greater than a specified function length - - nestif # Reports deeply nested if statements - - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - - nilnil # Checks that there is no simultaneous return of nil error and an invalid value. - # - noctx # noctx finds sending http request without context.Context - - nolintlint # Reports ill-formed or insufficient nolint directives - # - nonamedreturns # Reports all named returns - - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL. - - predeclared # find code that shadows one of Go's predeclared identifiers - - promlinter # Check Prometheus metrics naming via promlint - - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. - - rowserrcheck # checks whether Err of rows is checked successfully - - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - - stylecheck # Stylecheck is a replacement for golint - - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - - testpackage # linter that makes you use a separate _test package - - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - - unconvert # Remove unnecessary type conversions - - unparam # Reports unused function parameters - - wastedassign # wastedassign finds wasted assignment statements. - - whitespace # Tool for detection of leading and trailing whitespace - ## you may want to enable - #- decorder # check declaration order and count of types, constants, variables and functions - #- exhaustruct # Checks if all structure fields are initialized - #- goheader # Checks is file header matches to pattern - #- ireturn # Accept Interfaces, Return Concrete Types - #- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated - #- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope - #- wrapcheck # Checks that errors returned from external packages are wrapped - ## disabled - #- containedctx # containedctx is a linter that detects struct contained context.Context field - #- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages - #- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - #- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted. - #- forcetypeassert # [replaced by errcheck] finds forced type assertions - #- gci # Gci controls golang package import order and makes it always deterministic. - #- godox # Tool for detection of FIXME, TODO and other comment keywords - #- goerr113 # [too strict] Golang linter to check the errors handling expressions - #- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - #- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed. - #- grouper # An analyzer to analyze expression groups. - #- ifshort # Checks that your code uses short syntax for if-statements whenever possible - #- importas # Enforces consistent import aliases - #- maintidx # maintidx measures the maintainability index of each function. - #- misspell # [useless] Finds commonly misspelled English words in comments - #- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity - #- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14 - #- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test - #- tagliatelle # Checks the struct tags. - #- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - #- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines! - ## deprecated - #- exhaustivestruct # [deprecated, replaced by exhaustruct] Checks if all struct's fields are initialized - #- golint # [deprecated, replaced by revive] Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes - #- interfacer # [deprecated] Linter that suggests narrower interface types - #- maligned # [deprecated, replaced by govet fieldalignment] Tool to detect Go structs that would take less memory if their fields were sorted - #- scopelint # [deprecated, replaced by exportloopref] Scopelint checks for unpinned variables in go programs - - + - asciicheck + - bidichk + - bodyclose + - contextcheck + - cyclop + - dupl + - durationcheck + - errcheck + - errname + - errorlint + - exhaustive + - forbidigo + - funlen + - gochecknoinits + - gocognit + - goconst + - gocritic + - gocyclo + - godot + - gomoddirectives + - gomodguard + - goprintffuncname + - gosec + - govet + - ineffassign + - lll + - makezero + - mnd + - nestif + - nilerr + - nilnil + - nolintlint + - nosprintfhostport + - predeclared + - promlinter + - revive + - rowserrcheck + - sqlclosecheck + - staticcheck + - testpackage + - tparallel + - unconvert + - unparam + - unused + - usetesting + - wastedassign + - whitespace + settings: + cyclop: + max-complexity: 30 + package-average: 10 + errcheck: + check-type-assertions: true + funlen: + lines: 100 + statements: 50 + gocognit: + min-complexity: 20 + gocritic: + settings: + captLocal: + paramsOnly: false + underef: + skipRecvDeref: false + gomodguard: + blocked: + modules: + - github.com/golang/protobuf: + recommendations: + - google.golang.org/protobuf + reason: see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules + - github.com/satori/go.uuid: + recommendations: + - github.com/google/uuid + reason: satori's package is not maintained + - github.com/gofrs/uuid: + recommendations: + - github.com/google/uuid + reason: 'see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw' + govet: + disable: + - fieldalignment + enable-all: true + settings: + shadow: + strict: true + mnd: + ignored-functions: + - os.Chmod + - os.Mkdir + - os.MkdirAll + - os.OpenFile + - os.WriteFile + - prometheus.ExponentialBuckets + - prometheus.ExponentialBucketsRange + - prometheus.LinearBuckets + - strconv.FormatFloat + - strconv.FormatInt + - strconv.FormatUint + - strconv.ParseFloat + - strconv.ParseInt + - strconv.ParseUint + nakedret: + max-func-lines: 0 + nolintlint: + require-explanation: true + require-specific: true + allow-no-explanation: + - funlen + - gocognit + - lll + rowserrcheck: + packages: + - github.com/jmoiron/sqlx + exclusions: + generated: lax + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + rules: + - linters: + - forbidigo + - mnd + - revive + path : ^examples/.*\.go$ + - linters: + - lll + source: ^//\s*go:generate\s + - linters: + - godot + source: (noinspection|TODO) + - linters: + - gocritic + source: //noinspection + - linters: + - errorlint + source: ^\s+if _, ok := err\.\([^.]+\.InternalError\); ok { + - linters: + - bodyclose + - dupl + - funlen + - goconst + - gosec + - noctx + - wrapcheck + - staticcheck + path: _test\.go + paths: + - third_party$ + - builtin$ + - examples$ issues: - # Maximum count of issues with the same text. - # Set to 0 to disable. - # Default: 3 max-same-issues: 50 - - exclude-rules: - - source: "^//\\s*go:generate\\s" - linters: [ lll ] - - source: "(noinspection|TODO)" - linters: [ godot ] - - source: "//noinspection" - linters: [ gocritic ] - - source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {" - linters: [ errorlint ] - - path: "_test\\.go" - linters: - - bodyclose - - dupl - - funlen - - goconst - - gosec - - noctx - - wrapcheck +formatters: + enable: + - goimports + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..4dd184042 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,88 @@ +# Contributing Guidelines + +## Overview +Thank you for your interest in contributing to the "Go OpenAI" project! By following this guideline, we hope to ensure that your contributions are made smoothly and efficiently. The Go OpenAI project is licensed under the [Apache 2.0 License](https://github.com/sashabaranov/go-openai/blob/master/LICENSE), and we welcome contributions through GitHub pull requests. + +## Reporting Bugs +If you discover a bug, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to see if the issue has already been reported. If you're reporting a new issue, please use the "Bug report" template and provide detailed information about the problem, including steps to reproduce it. + +## Suggesting Features +If you want to suggest a new feature or improvement, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to ensure a similar suggestion hasn't already been made. Use the "Feature request" template to provide a detailed description of your suggestion. + +## Reporting Vulnerabilities +If you identify a security concern, please use the "Report a security vulnerability" template on the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to share the details. This report will only be viewable to repository maintainers. You will be credited if the advisory is published. + +## Questions for Users +If you have questions, please utilize [StackOverflow](https://stackoverflow.com/) or the [GitHub Discussions page](https://github.com/sashabaranov/go-openai/discussions). + +## Contributing Code +There might already be a similar pull requests submitted! Please search for [pull requests](https://github.com/sashabaranov/go-openai/pulls) before creating one. + +### Requirements for Merging a Pull Request + +The requirements to accept a pull request are as follows: + +- Features not provided by the OpenAI API will not be accepted. +- The functionality of the feature must match that of the official OpenAI API. +- All pull requests should be written in Go according to common conventions, formatted with `goimports`, and free of warnings from tools like `golangci-lint`. +- Include tests and ensure all tests pass. +- Maintain test coverage without any reduction. +- All pull requests require approval from at least one Go OpenAI maintainer. + +**Note:** +The merging method for pull requests in this repository is squash merge. + +### Creating a Pull Request +- Fork the repository. +- Create a new branch and commit your changes. +- Push that branch to GitHub. +- Start a new Pull Request on GitHub. (Please use the pull request template to provide detailed information.) + +**Note:** +If your changes introduce breaking changes, please prefix your pull request title with "[BREAKING_CHANGES]". + +### Code Style +In this project, we adhere to the standard coding style of Go. Your code should maintain consistency with the rest of the codebase. To achieve this, please format your code using tools like `goimports` and resolve any syntax or style issues with `golangci-lint`. + +**Run goimports:** +``` +go install golang.org/x/tools/cmd/goimports@latest +``` + +``` +goimports -w . +``` + +**Run golangci-lint:** +``` +go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest +``` + +``` +golangci-lint run --out-format=github-actions +``` + +### Unit Test +Please create or update tests relevant to your changes. Ensure all tests run successfully to verify that your modifications do not adversely affect other functionalities. + +**Run test:** +``` +go test -v ./... +``` + +### Integration Test +Integration tests are requested against the production version of the OpenAI API. These tests will verify that the library is properly coded against the actual behavior of the API, and will fail upon any incompatible change in the API. + +**Notes:** +These tests send real network traffic to the OpenAI API and may reach rate limits. Temporary network problems may also cause the test to fail. + +**Run integration test:** +``` +OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go +``` + +If the `OPENAI_TOKEN` environment variable is not available, integration tests will be skipped. + +--- + +We wholeheartedly welcome your active participation. Let's build an amazing project together! diff --git a/Makefile b/Makefile deleted file mode 100644 index 2e608aa0c..000000000 --- a/Makefile +++ /dev/null @@ -1,35 +0,0 @@ -##@ General - -# The help target prints out all targets with their descriptions organized -# beneath their categories. The categories are represented by '##@' and the -# target descriptions by '##'. The awk commands is responsible for reading the -# entire set of makefiles included in this invocation, looking for lines of the -# file as xyz: ## something, and then pretty-format the target and help. Then, -# if there's a line with ##@ something, that gets pretty-printed as a category. -# More info on the usage of ANSI control characters for terminal formatting: -# https://en.wikipedia.org/wiki/ANSI_escape_code#SGR_parameters -# More info on the awk command: -# http://linuxcommand.org/lc3_adv_awk.php - -.PHONY: help -help: ## Display this help. - @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) - - -##@ Development - -.PHONY: test -TEST_ARGS ?= -v -TEST_TARGETS ?= ./... -test: ## Test the Go modules within this package. - @ echo ▶️ go test $(TEST_ARGS) $(TEST_TARGETS) - go test $(TEST_ARGS) $(TEST_TARGETS) - @ echo ✅ success! - - -.PHONY: lint -LINT_TARGETS ?= ./... -lint: ## Lint Go code with the installed golangci-lint - @ echo "▶️ golangci-lint run" - golangci-lint run $(LINT_TARGETS) - @ echo "✅ golangci-lint run" diff --git a/README.md b/README.md index 272d853c6..77b85e519 100644 --- a/README.md +++ b/README.md @@ -3,18 +3,22 @@ [![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-openai)](https://goreportcard.com/report/github.com/sashabaranov/go-openai) [![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai) -This library provides Go clients for [OpenAI API](https://platform.openai.com/). We support: +This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). We support: -* ChatGPT +* ChatGPT 4o, o1 * GPT-3, GPT-4 -* DALL·E 2 +* DALL·E 2, DALL·E 3, GPT Image 1 * Whisper -### Installation: +## Installation + ``` go get github.com/sashabaranov/go-openai ``` +Currently, go-openai requires Go version 1.18 or greater. + +## Usage ### ChatGPT example usage: @@ -52,6 +56,17 @@ func main() { ``` +### Getting an OpenAI API Key: + +1. Visit the OpenAI website at [https://platform.openai.com/account/api-keys](https://platform.openai.com/account/api-keys). +2. If you don't have an account, click on "Sign Up" to create one. If you do, click "Log In". +3. Once logged in, navigate to your API key management page. +4. Click on "Create new secret key". +5. Enter a name for your new key, then click "Create secret key". +6. Your new API key will be displayed. Use this key to interact with the OpenAI API. + +**Note:** Your API key is sensitive information. Do not share it with anyone. + ### Other examples:
@@ -126,7 +141,7 @@ func main() { ctx := context.Background() req := openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", } @@ -159,7 +174,7 @@ func main() { ctx := context.Background() req := openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", Stream: true, @@ -342,6 +357,66 @@ func main() { ```
+
+GPT Image 1 image generation + +```go +package main + +import ( + "context" + "encoding/base64" + "fmt" + "os" + + openai "github.com/sashabaranov/go-openai" +) + +func main() { + c := openai.NewClient("your token") + ctx := context.Background() + + req := openai.ImageRequest{ + Prompt: "Parrot on a skateboard performing a trick. Large bold text \"SKATE MASTER\" banner at the bottom of the image. Cartoon style, natural light, high detail, 1:1 aspect ratio.", + Background: openai.CreateImageBackgroundOpaque, + Model: openai.CreateImageModelGptImage1, + Size: openai.CreateImageSize1024x1024, + N: 1, + Quality: openai.CreateImageQualityLow, + OutputCompression: 100, + OutputFormat: openai.CreateImageOutputFormatJPEG, + // Moderation: openai.CreateImageModerationLow, + // User: "", + } + + resp, err := c.CreateImage(ctx, req) + if err != nil { + fmt.Printf("Image creation Image generation with GPT Image 1error: %v\n", err) + return + } + + fmt.Println("Image Base64:", resp.Data[0].B64JSON) + + // Decode the base64 data + imgBytes, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON) + if err != nil { + fmt.Printf("Base64 decode error: %v\n", err) + return + } + + // Write image to file + outputPath := "generated_image.jpg" + err = os.WriteFile(outputPath, imgBytes, 0644) + if err != nil { + fmt.Printf("Failed to write image file: %v\n", err) + return + } + + fmt.Printf("The image was saved as %s\n", outputPath) +} +``` +
+
Configuring proxy @@ -435,8 +510,15 @@ import ( ) func main() { + config := openai.DefaultAzureConfig("your Azure OpenAI Key", "https://your Azure OpenAI Endpoint") + // If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function + // config.AzureModelMapperFunc = func(model string) string { + // azureModelMapping := map[string]string{ + // "gpt-3.5-turbo": "your gpt-3.5-turbo deployment name", + // } + // return azureModelMapping[model] + // } - config := openai.DefaultAzureConfig("your Azure OpenAI Key", "https://your Azure OpenAI Endpoint ", "your Model deployment name") client := openai.NewClientWithConfig(config) resp, err := client.CreateChatCompletion( context.Background(), @@ -450,7 +532,6 @@ func main() { }, }, ) - if err != nil { fmt.Printf("ChatCompletion error: %v\n", err) return @@ -458,9 +539,174 @@ func main() { fmt.Println(resp.Choices[0].Message.Content) } + +``` +
+ +
+Embedding Semantic Similarity + +```go +package main + +import ( + "context" + "log" + openai "github.com/sashabaranov/go-openai" + +) + +func main() { + client := openai.NewClient("your-token") + + // Create an EmbeddingRequest for the user query + queryReq := openai.EmbeddingRequest{ + Input: []string{"How many chucks would a woodchuck chuck"}, + Model: openai.AdaEmbeddingV2, + } + + // Create an embedding for the user query + queryResponse, err := client.CreateEmbeddings(context.Background(), queryReq) + if err != nil { + log.Fatal("Error creating query embedding:", err) + } + + // Create an EmbeddingRequest for the target text + targetReq := openai.EmbeddingRequest{ + Input: []string{"How many chucks would a woodchuck chuck if the woodchuck could chuck wood"}, + Model: openai.AdaEmbeddingV2, + } + + // Create an embedding for the target text + targetResponse, err := client.CreateEmbeddings(context.Background(), targetReq) + if err != nil { + log.Fatal("Error creating target embedding:", err) + } + + // Now that we have the embeddings for the user query and the target text, we + // can calculate their similarity. + queryEmbedding := queryResponse.Data[0] + targetEmbedding := targetResponse.Data[0] + + similarity, err := queryEmbedding.DotProduct(&targetEmbedding) + if err != nil { + log.Fatal("Error calculating dot product:", err) + } + + log.Printf("The similarity score between the query and the target is %f", similarity) +} + ```
+
+Azure OpenAI Embeddings + +```go +package main + +import ( + "context" + "fmt" + + openai "github.com/sashabaranov/go-openai" +) + +func main() { + + config := openai.DefaultAzureConfig("your Azure OpenAI Key", "https://your Azure OpenAI Endpoint") + config.APIVersion = "2023-05-15" // optional update to latest API version + + //If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function + //config.AzureModelMapperFunc = func(model string) string { + // azureModelMapping := map[string]string{ + // "gpt-3.5-turbo":"your gpt-3.5-turbo deployment name", + // } + // return azureModelMapping[model] + //} + + input := "Text to vectorize" + + client := openai.NewClientWithConfig(config) + resp, err := client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + Input: []string{input}, + Model: openai.AdaEmbeddingV2, + }) + + if err != nil { + fmt.Printf("CreateEmbeddings error: %v\n", err) + return + } + + vectors := resp.Data[0].Embedding // []float32 with 1536 dimensions + + fmt.Println(vectors[:10], "...", vectors[len(vectors)-10:]) +} +``` +
+ +
+JSON Schema for function calling + +It is now possible for chat completion to choose to call a function for more information ([see developer docs here](https://platform.openai.com/docs/guides/gpt/function-calling)). + +In order to describe the type of functions that can be called, a JSON schema must be provided. Many JSON schema libraries exist and are more advanced than what we can offer in this library, however we have included a simple `jsonschema` package for those who want to use this feature without formatting their own JSON schema payload. + +The developer documents give this JSON schema definition as an example: + +```json +{ + "name":"get_current_weather", + "description":"Get the current weather in a given location", + "parameters":{ + "type":"object", + "properties":{ + "location":{ + "type":"string", + "description":"The city and state, e.g. San Francisco, CA" + }, + "unit":{ + "type":"string", + "enum":[ + "celsius", + "fahrenheit" + ] + } + }, + "required":[ + "location" + ] + } +} +``` + +Using the `jsonschema` package, this schema could be created using structs as such: + +```go +FunctionDefinition{ + Name: "get_current_weather", + Parameters: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "location": { + Type: jsonschema.String, + Description: "The city and state, e.g. San Francisco, CA", + }, + "unit": { + Type: jsonschema.String, + Enum: []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, +} +``` + +The `Parameters` field of a `FunctionDefinition` can accept either of the above styles, or even a nested struct from another library (as long as it can be marshalled into JSON). +
+
Error handling @@ -485,5 +731,183 @@ if errors.As(err, &e) { ```
+
+Fine Tune Model + +```go +package main + +import ( + "context" + "fmt" + "github.com/sashabaranov/go-openai" +) + +func main() { + client := openai.NewClient("your token") + ctx := context.Background() + + // create a .jsonl file with your training data for conversational model + // {"prompt": "", "completion": ""} + // {"prompt": "", "completion": ""} + // {"prompt": "", "completion": ""} + + // chat models are trained using the following file format: + // {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]} + // {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]} + // {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]} + + // you can use openai cli tool to validate the data + // For more info - https://platform.openai.com/docs/guides/fine-tuning + + file, err := client.CreateFile(ctx, openai.FileRequest{ + FilePath: "training_prepared.jsonl", + Purpose: "fine-tune", + }) + if err != nil { + fmt.Printf("Upload JSONL file error: %v\n", err) + return + } + + // create a fine tuning job + // Streams events until the job is done (this often takes minutes, but can take hours if there are many jobs in the queue or your dataset is large) + // use below get method to know the status of your model + fineTuningJob, err := client.CreateFineTuningJob(ctx, openai.FineTuningJobRequest{ + TrainingFile: file.ID, + Model: "davinci-002", // gpt-3.5-turbo-0613, babbage-002. + }) + if err != nil { + fmt.Printf("Creating new fine tune model error: %v\n", err) + return + } + + fineTuningJob, err = client.RetrieveFineTuningJob(ctx, fineTuningJob.ID) + if err != nil { + fmt.Printf("Getting fine tune model error: %v\n", err) + return + } + fmt.Println(fineTuningJob.FineTunedModel) + + // once the status of fineTuningJob is `succeeded`, you can use your fine tune model in Completion Request or Chat Completion Request + + // resp, err := client.CreateCompletion(ctx, openai.CompletionRequest{ + // Model: fineTuningJob.FineTunedModel, + // Prompt: "your prompt", + // }) + // if err != nil { + // fmt.Printf("Create completion error %v\n", err) + // return + // } + // + // fmt.Println(resp.Choices[0].Text) +} +``` +
+ +
+Structured Outputs + +```go +package main + +import ( + "context" + "fmt" + "log" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" +) + +func main() { + client := openai.NewClient("your token") + ctx := context.Background() + + type Result struct { + Steps []struct { + Explanation string `json:"explanation"` + Output string `json:"output"` + } `json:"steps"` + FinalAnswer string `json:"final_answer"` + } + var result Result + schema, err := jsonschema.GenerateSchemaForType(result) + if err != nil { + log.Fatalf("GenerateSchemaForType error: %v", err) + } + resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "You are a helpful math tutor. Guide the user through the solution step by step.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "how can I solve 8x + 7 = -23", + }, + }, + ResponseFormat: &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONSchema, + JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ + Name: "math_reasoning", + Schema: schema, + Strict: true, + }, + }, + }) + if err != nil { + log.Fatalf("CreateChatCompletion error: %v", err) + } + err = schema.Unmarshal(resp.Choices[0].Message.Content, &result) + if err != nil { + log.Fatalf("Unmarshal schema error: %v", err) + } + fmt.Println(result) +} +``` +
See the `examples/` folder for more. +## Frequently Asked Questions + +### Why don't we get the same answer when specifying a temperature field of 0 and asking the same question? + +Even when specifying a temperature field of 0, it doesn't guarantee that you'll always get the same response. Several factors come into play. + +1. Go OpenAI Behavior: When you specify a temperature field of 0 in Go OpenAI, the omitempty tag causes that field to be removed from the request. Consequently, the OpenAI API applies the default value of 1. +2. Token Count for Input/Output: If there's a large number of tokens in the input and output, setting the temperature to 0 can still result in non-deterministic behavior. In particular, when using around 32k tokens, the likelihood of non-deterministic behavior becomes highest even with a temperature of 0. + +Due to the factors mentioned above, different answers may be returned even for the same question. + +**Workarounds:** +1. As of November 2023, use [the new `seed` parameter](https://platform.openai.com/docs/guides/text-generation/reproducible-outputs) in conjunction with the `system_fingerprint` response field, alongside Temperature management. +2. Try using `math.SmallestNonzeroFloat32`: By specifying `math.SmallestNonzeroFloat32` in the temperature field instead of 0, you can mimic the behavior of setting it to 0. +3. Limiting Token Count: By limiting the number of tokens in the input and output and especially avoiding large requests close to 32k tokens, you can reduce the risk of non-deterministic behavior. + +By adopting these strategies, you can expect more consistent results. + +**Related Issues:** +[omitempty option of request struct will generate incorrect request when parameter is 0.](https://github.com/sashabaranov/go-openai/issues/9) + +### Does Go OpenAI provide a method to count tokens? + +No, Go OpenAI does not offer a feature to count tokens, and there are no plans to provide such a feature in the future. However, if there's a way to implement a token counting feature with zero dependencies, it might be possible to merge that feature into Go OpenAI. Otherwise, it would be more appropriate to implement it in a dedicated library or repository. + +For counting tokens, you might find the following links helpful: +- [Counting Tokens For Chat API Calls](https://github.com/pkoukk/tiktoken-go#counting-tokens-for-chat-api-calls) +- [How to count tokens with tiktoken](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) + +**Related Issues:** +[Is it possible to join the implementation of GPT3 Tokenizer](https://github.com/sashabaranov/go-openai/issues/62) + +## Contributing + +By following [Contributing Guidelines](https://github.com/sashabaranov/go-openai/blob/master/CONTRIBUTING.md), we hope to ensure that your contributions are made smoothly and efficiently. + +## Thank you + +We want to take a moment to express our deepest gratitude to the [contributors](https://github.com/sashabaranov/go-openai/graphs/contributors) and sponsors of this project: +- [Carson Kahn](https://carsonkahn.com) of [Spindle AI](https://spindleai.com) + +To all of you: thank you. You've helped us achieve more than we ever imagined possible. Can't wait to see where we go next, together! diff --git a/api_integration_test.go b/api_integration_test.go new file mode 100644 index 000000000..7828d9451 --- /dev/null +++ b/api_integration_test.go @@ -0,0 +1,314 @@ +//go:build integration + +package openai_test + +import ( + "context" + "encoding/json" + "errors" + "io" + "os" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/sashabaranov/go-openai/jsonschema" +) + +func TestAPI(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := openai.NewClient(apiToken) + ctx := context.Background() + _, err = c.ListEngines(ctx) + checks.NoError(t, err, "ListEngines error") + + _, err = c.GetEngine(ctx, openai.GPT3Davinci002) + checks.NoError(t, err, "GetEngine error") + + fileRes, err := c.ListFiles(ctx) + checks.NoError(t, err, "ListFiles error") + + if len(fileRes.Files) > 0 { + _, err = c.GetFile(ctx, fileRes.Files[0].ID) + checks.NoError(t, err, "GetFile error") + } // else skip + + embeddingReq := openai.EmbeddingRequest{ + Input: []string{ + "The food was delicious and the waiter", + "Other examples of embedding request", + }, + Model: openai.AdaEmbeddingV2, + } + _, err = c.CreateEmbeddings(ctx, embeddingReq) + checks.NoError(t, err, "Embedding error") + + _, err = c.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + ) + + checks.NoError(t, err, "CreateChatCompletion (without name) returned error") + + _, err = c.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Name: "John_Doe", + Content: "Hello!", + }, + }, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (with name) returned error") + + _, err = c.CreateChatCompletion( + context.Background(), + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "What is the weather like in Boston?", + }, + }, + Functions: []openai.FunctionDefinition{{ + Name: "get_current_weather", + Parameters: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "location": { + Type: jsonschema.String, + Description: "The city and state, e.g. San Francisco, CA", + }, + "unit": { + Type: jsonschema.String, + Enum: []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }}, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (with functions) returned error") +} + +func TestCompletionStream(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + c := openai.NewClient(apiToken) + ctx := context.Background() + + stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: openai.GPT3Babbage002, + MaxTokens: 5, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + counter := 0 + for { + _, err = stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Errorf("Stream error: %v", err) + } else { + counter++ + } + } + if counter == 0 { + t.Error("Stream did not return any responses") + } +} + +func TestAPIError(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := openai.NewClient(apiToken + "_invalid") + ctx := context.Background() + _, err = c.ListEngines(ctx) + checks.HasError(t, err, "ListEngines should fail with an invalid key") + + var apiErr *openai.APIError + if !errors.As(err, &apiErr) { + t.Fatalf("Error is not an APIError: %+v", err) + } + + if apiErr.HTTPStatusCode != 401 { + t.Fatalf("Unexpected API error status code: %d", apiErr.HTTPStatusCode) + } + + switch v := apiErr.Code.(type) { + case string: + if v != "invalid_api_key" { + t.Fatalf("Unexpected API error code: %s", v) + } + default: + t.Fatalf("Unexpected API error code type: %T", v) + } + + if apiErr.Error() == "" { + t.Fatal("Empty error message occurred") + } +} + +func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := openai.NewClient(apiToken) + ctx := context.Background() + + type MyStructuredResponse struct { + PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"` + CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` + KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"` + SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` + } + var result MyStructuredResponse + schema, err := jsonschema.GenerateSchemaForType(result) + if err != nil { + t.Fatal("CreateChatCompletion (use json_schema response) GenerateSchemaForType error") + } + resp, err := c.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "Please enter a string, and we will convert it into the following naming conventions:" + + "1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." + + "2. CamelCase: The first word starts with a lowercase letter, " + + "and subsequent words start with an uppercase letter, with no spaces or separators." + + "3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." + + "4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "Hello World", + }, + }, + ResponseFormat: &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONSchema, + JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ + Name: "cases", + Schema: schema, + Strict: true, + }, + }, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error") + if err == nil { + err = schema.Unmarshal(resp.Choices[0].Message.Content, &result) + checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") + } +} + +func TestChatCompletionStructuredOutputsFunctionCalling(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := openai.NewClient(apiToken) + ctx := context.Background() + + resp, err := c.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "Please enter a string, and we will convert it into the following naming conventions:" + + "1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." + + "2. CamelCase: The first word starts with a lowercase letter, " + + "and subsequent words start with an uppercase letter, with no spaces or separators." + + "3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." + + "4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "Hello World", + }, + }, + Tools: []openai.Tool{ + { + Type: openai.ToolTypeFunction, + Function: &openai.FunctionDefinition{ + Name: "display_cases", + Strict: true, + Parameters: &jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "PascalCase": { + Type: jsonschema.String, + }, + "CamelCase": { + Type: jsonschema.String, + }, + "KebabCase": { + Type: jsonschema.String, + }, + "SnakeCase": { + Type: jsonschema.String, + }, + }, + Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"}, + AdditionalProperties: false, + }, + }, + }, + }, + ToolChoice: openai.ToolChoice{ + Type: openai.ToolTypeFunction, + Function: openai.ToolFunction{ + Name: "display_cases", + }, + }, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) returned error") + var result = make(map[string]string) + err = json.Unmarshal([]byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments), &result) + checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) unmarshal error") + for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} { + if _, ok := result[key]; !ok { + t.Errorf("key:%s does not exist.", key) + } + } +} diff --git a/api_internal_test.go b/api_internal_test.go index 9651ad402..09677968a 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -94,7 +94,7 @@ func TestRequestAuthHeader(t *testing.T) { az.OrgID = c.OrgID cli := NewClientWithConfig(az) - req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil) + req, err := cli.newRequest(context.Background(), "POST", "/chat/completions") if err != nil { t.Errorf("Failed to create request: %v", err) } @@ -109,35 +109,88 @@ func TestRequestAuthHeader(t *testing.T) { func TestAzureFullURL(t *testing.T) { cases := []struct { - Name string - BaseURL string - Engine string - Expect string + Name string + BaseURL string + AzureModelMapper map[string]string + Suffix string + Model string + Expect string }{ { "AzureBaseURLWithSlashAutoStrip", "https://httpbin.org/", + nil, + "/chat/completions", "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + - "/chat/completions?api-version=2023-03-15-preview", + "/chat/completions?api-version=2023-05-15", }, { "AzureBaseURLWithoutSlashOK", "https://httpbin.org", + nil, + "/chat/completions", "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + - "/chat/completions?api-version=2023-03-15-preview", + "/chat/completions?api-version=2023-05-15", + }, + { + "", + "https://httpbin.org", + nil, + "/assistants?limit=10", + "chatgpt-demo", + "https://httpbin.org/openai/assistants?api-version=2023-05-15&limit=10", }, } for _, c := range cases { t.Run(c.Name, func(t *testing.T) { - az := DefaultAzureConfig("dummy", c.BaseURL, c.Engine) + az := DefaultAzureConfig("dummy", c.BaseURL) cli := NewClientWithConfig(az) // /openai/deployments/{engine}/chat/completions?api-version={api_version} - actual := cli.fullURL("/chat/completions") + actual := cli.fullURL(c.Suffix, withModel(c.Model)) + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + t.Logf("Full URL: %s", actual) + }) + } +} + +func TestCloudflareAzureFullURL(t *testing.T) { + cases := []struct { + Name string + BaseURL string + Suffix string + Expect string + }{ + { + "CloudflareAzureBaseURLWithSlashAutoStrip", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/", + "/chat/completions", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + + "chat/completions?api-version=2023-05-15", + }, + { + "", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo", + "/assistants?limit=10", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo" + + "/assistants?api-version=2023-05-15&limit=10", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + az := DefaultAzureConfig("dummy", c.BaseURL) + az.APIType = APITypeCloudflareAzure + + cli := NewClientWithConfig(az) + + actual := cli.fullURL(c.Suffix) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } diff --git a/api_test.go b/api_test.go deleted file mode 100644 index 78fd5cc6d..000000000 --- a/api_test.go +++ /dev/null @@ -1,264 +0,0 @@ -package openai_test - -import ( - "context" - "encoding/json" - "errors" - "io" - "net/http" - "net/http/httptest" - "os" - "testing" - - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" -) - -func TestAPI(t *testing.T) { - apiToken := os.Getenv("OPENAI_TOKEN") - if apiToken == "" { - t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") - } - - var err error - c := NewClient(apiToken) - ctx := context.Background() - _, err = c.ListEngines(ctx) - checks.NoError(t, err, "ListEngines error") - - _, err = c.GetEngine(ctx, "davinci") - checks.NoError(t, err, "GetEngine error") - - fileRes, err := c.ListFiles(ctx) - checks.NoError(t, err, "ListFiles error") - - if len(fileRes.Files) > 0 { - _, err = c.GetFile(ctx, fileRes.Files[0].ID) - checks.NoError(t, err, "GetFile error") - } // else skip - - embeddingReq := EmbeddingRequest{ - Input: []string{ - "The food was delicious and the waiter", - "Other examples of embedding request", - }, - Model: AdaSearchQuery, - } - _, err = c.CreateEmbeddings(ctx, embeddingReq) - checks.NoError(t, err, "Embedding error") - - _, err = c.CreateChatCompletion( - ctx, - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ - { - Role: ChatMessageRoleUser, - Content: "Hello!", - }, - }, - }, - ) - - checks.NoError(t, err, "CreateChatCompletion (without name) returned error") - - _, err = c.CreateChatCompletion( - ctx, - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ - { - Role: ChatMessageRoleUser, - Name: "John_Doe", - Content: "Hello!", - }, - }, - }, - ) - checks.NoError(t, err, "CreateChatCompletion (with name) returned error") - - stream, err := c.CreateCompletionStream(ctx, CompletionRequest{ - Prompt: "Ex falso quodlibet", - Model: GPT3Ada, - MaxTokens: 5, - Stream: true, - }) - checks.NoError(t, err, "CreateCompletionStream returned error") - defer stream.Close() - - counter := 0 - for { - _, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - t.Errorf("Stream error: %v", err) - } else { - counter++ - } - } - if counter == 0 { - t.Error("Stream did not return any responses") - } -} - -func TestAPIError(t *testing.T) { - apiToken := os.Getenv("OPENAI_TOKEN") - if apiToken == "" { - t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") - } - - var err error - c := NewClient(apiToken + "_invalid") - ctx := context.Background() - _, err = c.ListEngines(ctx) - checks.HasError(t, err, "ListEngines should fail with an invalid key") - - var apiErr *APIError - if !errors.As(err, &apiErr) { - t.Fatalf("Error is not an APIError: %+v", err) - } - - if apiErr.HTTPStatusCode != 401 { - t.Fatalf("Unexpected API error status code: %d", apiErr.HTTPStatusCode) - } - - switch v := apiErr.Code.(type) { - case string: - if v != "invalid_api_key" { - t.Fatalf("Unexpected API error code: %s", v) - } - default: - t.Fatalf("Unexpected API error code type: %T", v) - } - - if apiErr.Error() == "" { - t.Fatal("Empty error message occurred") - } -} - -func TestAPIErrorUnmarshalJSONInteger(t *testing.T) { - var apiErr APIError - response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.NoError(t, err, "Unexpected Unmarshal API response error") - - switch v := apiErr.Code.(type) { - case int: - if v != 418 { - t.Fatalf("Unexpected API code integer: %d; expected 418", v) - } - default: - t.Fatalf("Unexpected API error code type: %T", v) - } -} - -func TestAPIErrorUnmarshalJSONString(t *testing.T) { - var apiErr APIError - response := `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.NoError(t, err, "Unexpected Unmarshal API response error") - - switch v := apiErr.Code.(type) { - case string: - if v != "teapot" { - t.Fatalf("Unexpected API code string: %s; expected `teapot`", v) - } - default: - t.Fatalf("Unexpected API error code type: %T", v) - } -} - -func TestAPIErrorUnmarshalJSONNoCode(t *testing.T) { - // test integer code - response := `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}` - var apiErr APIError - err := json.Unmarshal([]byte(response), &apiErr) - checks.NoError(t, err, "Unexpected Unmarshal API response error") - - switch v := apiErr.Code.(type) { - case nil: - default: - t.Fatalf("Unexpected API error code type: %T", v) - } -} - -func TestAPIErrorUnmarshalInvalidData(t *testing.T) { - apiErr := APIError{} - data := []byte(`--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`) - err := apiErr.UnmarshalJSON(data) - checks.HasError(t, err, "Expected error when unmarshaling invalid data") - - if apiErr.Code != nil { - t.Fatalf("Expected nil code, got %q", apiErr.Code) - } - if apiErr.Message != "" { - t.Fatalf("Expected empty message, got %q", apiErr.Message) - } - if apiErr.Param != nil { - t.Fatalf("Expected nil param, got %q", *apiErr.Param) - } - if apiErr.Type != "" { - t.Fatalf("Expected empty type, got %q", apiErr.Type) - } -} - -func TestAPIErrorUnmarshalJSONInvalidParam(t *testing.T) { - var apiErr APIError - response := `{"code":418,"message":"I'm a teapot","param":true,"type":"teapot_error"}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.HasError(t, err, "Param should be a string") -} - -func TestAPIErrorUnmarshalJSONInvalidType(t *testing.T) { - var apiErr APIError - response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":true}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.HasError(t, err, "Type should be a string") -} - -func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) { - var apiErr APIError - response := `{"code":418,"message":false,"param":"prompt","type":"teapot_error"}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.HasError(t, err, "Message should be a string") -} - -func TestRequestError(t *testing.T) { - var err error - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusTeapot) - })) - defer ts.Close() - - config := DefaultConfig("dummy") - config.BaseURL = ts.URL - c := NewClientWithConfig(config) - ctx := context.Background() - _, err = c.ListEngines(ctx) - checks.HasError(t, err, "ListEngines did not fail") - - var reqErr *RequestError - if !errors.As(err, &reqErr) { - t.Fatalf("Error is not a RequestError: %+v", err) - } - - if reqErr.HTTPStatusCode != 418 { - t.Fatalf("Unexpected request error status code: %d", reqErr.HTTPStatusCode) - } - - if reqErr.Unwrap() == nil { - t.Fatalf("Empty request error occurred") - } -} - -// numTokens Returns the number of GPT-3 encoded tokens in the given text. -// This function approximates based on the rule of thumb stated by OpenAI: -// https://beta.openai.com/tokenizer -// -// TODO: implement an actual tokenizer for GPT-3 and Codex (once available) -func numTokens(s string) int { - return int(float32(len(s)) / 4) -} diff --git a/assistant.go b/assistant.go new file mode 100644 index 000000000..8aab5bcf0 --- /dev/null +++ b/assistant.go @@ -0,0 +1,325 @@ +package openai + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" +) + +const ( + assistantsSuffix = "/assistants" + assistantsFilesSuffix = "/files" +) + +type Assistant struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Model string `json:"model"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"tools"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` // Deprecated in v2 + Metadata map[string]any `json:"metadata,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` + ResponseFormat any `json:"response_format,omitempty"` + + httpHeader +} + +type AssistantToolType string + +const ( + AssistantToolTypeCodeInterpreter AssistantToolType = "code_interpreter" + AssistantToolTypeRetrieval AssistantToolType = "retrieval" + AssistantToolTypeFunction AssistantToolType = "function" + AssistantToolTypeFileSearch AssistantToolType = "file_search" +) + +type AssistantTool struct { + Type AssistantToolType `json:"type"` + Function *FunctionDefinition `json:"function,omitempty"` +} + +type AssistantToolFileSearch struct { + VectorStoreIDs []string `json:"vector_store_ids"` +} + +type AssistantToolCodeInterpreter struct { + FileIDs []string `json:"file_ids"` +} + +type AssistantToolResource struct { + FileSearch *AssistantToolFileSearch `json:"file_search,omitempty"` + CodeInterpreter *AssistantToolCodeInterpreter `json:"code_interpreter,omitempty"` +} + +// AssistantRequest provides the assistant request parameters. +// When modifying the tools the API functions as the following: +// If Tools is undefined, no changes are made to the Assistant's tools. +// If Tools is empty slice it will effectively delete all of the Assistant's tools. +// If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools. +type AssistantRequest struct { + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"-"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` + ResponseFormat any `json:"response_format,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` +} + +// MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases +// If Tools is nil, the field is omitted from the JSON. +// If Tools is an empty slice, it's included in the JSON as an empty array ([]). +// If Tools is populated, it's included in the JSON with the elements. +func (a AssistantRequest) MarshalJSON() ([]byte, error) { + type Alias AssistantRequest + assistantAlias := &struct { + Tools *[]AssistantTool `json:"tools,omitempty"` + *Alias + }{ + Alias: (*Alias)(&a), + } + + if a.Tools != nil { + assistantAlias.Tools = &a.Tools + } + + return json.Marshal(assistantAlias) +} + +// AssistantsList is a list of assistants. +type AssistantsList struct { + Assistants []Assistant `json:"data"` + LastID *string `json:"last_id"` + FirstID *string `json:"first_id"` + HasMore bool `json:"has_more"` + httpHeader +} + +type AssistantDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + +type AssistantFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + AssistantID string `json:"assistant_id"` + + httpHeader +} + +type AssistantFileRequest struct { + FileID string `json:"file_id"` +} + +type AssistantFilesList struct { + AssistantFiles []AssistantFile `json:"data"` + + httpHeader +} + +// CreateAssistant creates a new assistant. +func (c *Client) CreateAssistant(ctx context.Context, request AssistantRequest) (response Assistant, err error) { + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveAssistant retrieves an assistant. +func (c *Client) RetrieveAssistant( + ctx context.Context, + assistantID string, +) (response Assistant, err error) { + urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ModifyAssistant modifies an assistant. +func (c *Client) ModifyAssistant( + ctx context.Context, + assistantID string, + request AssistantRequest, +) (response Assistant, err error) { + urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// DeleteAssistant deletes an assistant. +func (c *Client) DeleteAssistant( + ctx context.Context, + assistantID string, +) (response AssistantDeleteResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ListAssistants Lists the currently available assistants. +func (c *Client) ListAssistants( + ctx context.Context, + limit *int, + order *string, + after *string, + before *string, +) (response AssistantsList, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if order != nil { + urlValues.Add("order", *order) + } + if after != nil { + urlValues.Add("after", *after) + } + if before != nil { + urlValues.Add("before", *before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", assistantsSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// CreateAssistantFile creates a new assistant file. +func (c *Client) CreateAssistantFile( + ctx context.Context, + assistantID string, + request AssistantFileRequest, +) (response AssistantFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveAssistantFile retrieves an assistant file. +func (c *Client) RetrieveAssistantFile( + ctx context.Context, + assistantID string, + fileID string, +) (response AssistantFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// DeleteAssistantFile deletes an existing file. +func (c *Client) DeleteAssistantFile( + ctx context.Context, + assistantID string, + fileID string, +) (err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, nil) + return +} + +// ListAssistantFiles Lists the currently available files for an assistant. +func (c *Client) ListAssistantFiles( + ctx context.Context, + assistantID string, + limit *int, + order *string, + after *string, + before *string, +) (response AssistantFilesList, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if order != nil { + urlValues.Add("order", *order) + } + if after != nil { + urlValues.Add("after", *after) + } + if before != nil { + urlValues.Add("before", *before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/assistant_test.go b/assistant_test.go new file mode 100644 index 000000000..40de0e50f --- /dev/null +++ b/assistant_test.go @@ -0,0 +1,447 @@ +package openai_test + +import ( + "context" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +// TestAssistant Tests the assistant endpoint of the API using the mocked server. +func TestAssistant(t *testing.T) { + assistantID := "asst_abc123" + assistantName := "Ambrogio" + assistantDescription := "Ambrogio is a friendly assistant." + assistantInstructions := `You are a personal math tutor. +When asked a question, write and run Python code to answer the question.` + assistantFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" + limit := 20 + order := "desc" + after := "asst_abc122" + before := "asst_abc124" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/assistants/"+assistantID+"/files/"+assistantFileID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodDelete { + fmt.Fprintln(w, `{ + id: "file-wB6RM6wHdA49HfS2DJ9fEyrH", + object: "assistant.file.deleted", + deleted: true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/assistants/"+assistantID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFilesList{ + AssistantFiles: []openai.AssistantFile{ + { + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.AssistantFileRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: request.FileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/assistants/"+assistantID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + Description: &assistantDescription, + Instructions: &assistantInstructions, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.Assistant + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "asst_abc123", + "object": "assistant.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/assistants", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.AssistantRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantsList{ + LastID: &assistantID, + FirstID: &assistantID, + Assistants: []openai.Assistant{ + { + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + Description: &assistantDescription, + Instructions: &assistantInstructions, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + t.Run("create_assistant", func(t *testing.T) { + _, err := client.CreateAssistant(ctx, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "CreateAssistant error") + }) + + t.Run("retrieve_assistant", func(t *testing.T) { + _, err := client.RetrieveAssistant(ctx, assistantID) + checks.NoError(t, err, "RetrieveAssistant error") + }) + + t.Run("delete_assistant", func(t *testing.T) { + _, err := client.DeleteAssistant(ctx, assistantID) + checks.NoError(t, err, "DeleteAssistant error") + }) + + t.Run("list_assistant", func(t *testing.T) { + _, err := client.ListAssistants(ctx, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistants error") + }) + + t.Run("create_assistant_file", func(t *testing.T) { + _, err := client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{ + FileID: assistantFileID, + }) + checks.NoError(t, err, "CreateAssistantFile error") + }) + + t.Run("list_assistant_files", func(t *testing.T) { + _, err := client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistantFiles error") + }) + + t.Run("retrieve_assistant_file", func(t *testing.T) { + _, err := client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "RetrieveAssistantFile error") + }) + + t.Run("delete_assistant_file", func(t *testing.T) { + err := client.DeleteAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "DeleteAssistantFile error") + }) + + t.Run("modify_assistant_no_tools", func(t *testing.T) { + assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "ModifyAssistant error") + + if assistant.Tools != nil { + t.Errorf("expected nil got %v", assistant.Tools) + } + }) + + t.Run("modify_assistant_with_tools", func(t *testing.T) { + assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + Tools: []openai.AssistantTool{{Type: openai.AssistantToolTypeFunction}}, + }) + checks.NoError(t, err, "ModifyAssistant error") + + if assistant.Tools == nil || len(assistant.Tools) != 1 { + t.Errorf("expected a slice got %v", assistant.Tools) + } + }) + + t.Run("modify_assistant_empty_tools", func(t *testing.T) { + assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + Tools: make([]openai.AssistantTool, 0), + }) + + checks.NoError(t, err, "ModifyAssistant error") + + if assistant.Tools == nil { + t.Errorf("expected a slice got %v", assistant.Tools) + } + }) +} + +func TestAzureAssistant(t *testing.T) { + assistantID := "asst_abc123" + assistantName := "Ambrogio" + assistantDescription := "Ambrogio is a friendly assistant." + assistantInstructions := `You are a personal math tutor. +When asked a question, write and run Python code to answer the question.` + assistantFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" + limit := 20 + order := "desc" + after := "asst_abc122" + before := "asst_abc124" + + client, server, teardown := setupAzureTestServer() + defer teardown() + + server.RegisterHandler( + "/openai/assistants/"+assistantID+"/files/"+assistantFileID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodDelete { + fmt.Fprintln(w, `{ + id: "file-wB6RM6wHdA49HfS2DJ9fEyrH", + object: "assistant.file.deleted", + deleted: true + }`) + } + }, + ) + + server.RegisterHandler( + "/openai/assistants/"+assistantID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFilesList{ + AssistantFiles: []openai.AssistantFile{ + { + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.AssistantFileRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: request.FileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/openai/assistants/"+assistantID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + Description: &assistantDescription, + Instructions: &assistantInstructions, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.AssistantRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "asst_abc123", + "object": "assistant.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/openai/assistants", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.AssistantRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantsList{ + LastID: &assistantID, + FirstID: &assistantID, + Assistants: []openai.Assistant{ + { + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + Description: &assistantDescription, + Instructions: &assistantInstructions, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateAssistant(ctx, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "CreateAssistant error") + + _, err = client.RetrieveAssistant(ctx, assistantID) + checks.NoError(t, err, "RetrieveAssistant error") + + _, err = client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "ModifyAssistant error") + + _, err = client.DeleteAssistant(ctx, assistantID) + checks.NoError(t, err, "DeleteAssistant error") + + _, err = client.ListAssistants(ctx, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistants error") + + _, err = client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{ + FileID: assistantFileID, + }) + checks.NoError(t, err, "CreateAssistantFile error") + + _, err = client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistantFiles error") + + _, err = client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "RetrieveAssistantFile error") + + err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "DeleteAssistantFile error") +} diff --git a/audio.go b/audio.go index d22daf98c..f321f93d6 100644 --- a/audio.go +++ b/audio.go @@ -4,8 +4,11 @@ import ( "bytes" "context" "fmt" + "io" "net/http" "os" + + utils "github.com/sashabaranov/go-openai/internal" ) // Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI. @@ -17,25 +20,76 @@ const ( type AudioResponseFormat string const ( - AudioResponseFormatJSON AudioResponseFormat = "json" - AudioResponseFormatSRT AudioResponseFormat = "srt" - AudioResponseFormatVTT AudioResponseFormat = "vtt" + AudioResponseFormatJSON AudioResponseFormat = "json" + AudioResponseFormatText AudioResponseFormat = "text" + AudioResponseFormatSRT AudioResponseFormat = "srt" + AudioResponseFormatVerboseJSON AudioResponseFormat = "verbose_json" + AudioResponseFormatVTT AudioResponseFormat = "vtt" +) + +type TranscriptionTimestampGranularity string + +const ( + TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word" + TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment" ) // AudioRequest represents a request structure for audio API. -// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient. type AudioRequest struct { - Model string - FilePath string - Prompt string // For translation, it should be in English - Temperature float32 - Language string // For translation, just do not use it. It seems "en" works, not confirmed... - Format AudioResponseFormat + Model string + + // FilePath is either an existing file in your filesystem or a filename representing the contents of Reader. + FilePath string + + // Reader is an optional io.Reader when you do not want to use an existing file. + Reader io.Reader + + Prompt string + Temperature float32 + Language string // Only for transcription. + Format AudioResponseFormat + TimestampGranularities []TranscriptionTimestampGranularity // Only for transcription. } // AudioResponse represents a response structure for audio API. type AudioResponse struct { + Task string `json:"task"` + Language string `json:"language"` + Duration float64 `json:"duration"` + Segments []struct { + ID int `json:"id"` + Seek int `json:"seek"` + Start float64 `json:"start"` + End float64 `json:"end"` + Text string `json:"text"` + Tokens []int `json:"tokens"` + Temperature float64 `json:"temperature"` + AvgLogprob float64 `json:"avg_logprob"` + CompressionRatio float64 `json:"compression_ratio"` + NoSpeechProb float64 `json:"no_speech_prob"` + Transient bool `json:"transient"` + } `json:"segments"` + Words []struct { + Word string `json:"word"` + Start float64 `json:"start"` + End float64 `json:"end"` + } `json:"words"` Text string `json:"text"` + + httpHeader +} + +type audioTextResponse struct { + Text string `json:"text"` + + httpHeader +} + +func (r *audioTextResponse) ToAudioResponse() AudioResponse { + return AudioResponse{ + Text: r.Text, + httpHeader: r.httpHeader, + } } // CreateTranscription — API call to create a transcription. Returns transcribed text. @@ -68,16 +122,23 @@ func (c *Client) callAudioAPI( } urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), &formBody) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(&formBody), + withContentType(builder.FormDataContentType()), + ) if err != nil { return AudioResponse{}, err } - req.Header.Add("Content-Type", builder.formDataContentType()) if request.HasJSONResponse() { err = c.sendRequest(req, &response) } else { - err = c.sendRequest(req, &response.Text) + var textResponse audioTextResponse + err = c.sendRequest(req, &textResponse) + response = textResponse.ToAudioResponse() } if err != nil { return AudioResponse{}, err @@ -87,31 +148,25 @@ func (c *Client) callAudioAPI( // HasJSONResponse returns true if the response format is JSON. func (r AudioRequest) HasJSONResponse() bool { - return r.Format == "" || r.Format == AudioResponseFormatJSON + return r.Format == "" || r.Format == AudioResponseFormatJSON || r.Format == AudioResponseFormatVerboseJSON } // audioMultipartForm creates a form with audio file contents and the name of the model to use for // audio processing. -func audioMultipartForm(request AudioRequest, b formBuilder) error { - f, err := os.Open(request.FilePath) - if err != nil { - return fmt.Errorf("opening audio file: %w", err) - } - defer f.Close() - - err = b.createFormFile("file", f) +func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error { + err := createFileField(request, b) if err != nil { - return fmt.Errorf("creating form file: %w", err) + return err } - err = b.writeField("model", request.Model) + err = b.WriteField("model", request.Model) if err != nil { return fmt.Errorf("writing model name: %w", err) } // Create a form field for the prompt (if provided) if request.Prompt != "" { - err = b.writeField("prompt", request.Prompt) + err = b.WriteField("prompt", request.Prompt) if err != nil { return fmt.Errorf("writing prompt: %w", err) } @@ -119,7 +174,7 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error { // Create a form field for the format (if provided) if request.Format != "" { - err = b.writeField("response_format", string(request.Format)) + err = b.WriteField("response_format", string(request.Format)) if err != nil { return fmt.Errorf("writing format: %w", err) } @@ -127,7 +182,7 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error { // Create a form field for the temperature (if provided) if request.Temperature != 0 { - err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature)) + err = b.WriteField("temperature", fmt.Sprintf("%.2f", request.Temperature)) if err != nil { return fmt.Errorf("writing temperature: %w", err) } @@ -135,12 +190,45 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error { // Create a form field for the language (if provided) if request.Language != "" { - err = b.writeField("language", request.Language) + err = b.WriteField("language", request.Language) if err != nil { return fmt.Errorf("writing language: %w", err) } } + if len(request.TimestampGranularities) > 0 { + for _, tg := range request.TimestampGranularities { + err = b.WriteField("timestamp_granularities[]", string(tg)) + if err != nil { + return fmt.Errorf("writing timestamp_granularities[]: %w", err) + } + } + } + // Close the multipart writer - return b.close() + return b.Close() +} + +// createFileField creates the "file" form field from either an existing file or by using the reader. +func createFileField(request AudioRequest, b utils.FormBuilder) error { + if request.Reader != nil { + err := b.CreateFormFileReader("file", request.Reader, request.FilePath) + if err != nil { + return fmt.Errorf("creating form using reader: %w", err) + } + return nil + } + + f, err := os.Open(request.FilePath) + if err != nil { + return fmt.Errorf("opening audio file: %w", err) + } + defer f.Close() + + err = b.CreateFormFile("file", f) + if err != nil { + return fmt.Errorf("creating form file: %w", err) + } + + return nil } diff --git a/audio_api_test.go b/audio_api_test.go new file mode 100644 index 000000000..6c6a35643 --- /dev/null +++ b/audio_api_test.go @@ -0,0 +1,160 @@ +package openai_test + +import ( + "bytes" + "context" + "errors" + "io" + "mime" + "mime/multipart" + "net/http" + "path/filepath" + "strings" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +// TestAudio Tests the transcription and translation endpoints of the API using the mocked server. +func TestAudio(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) + server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) + + testcases := []struct { + name string + createFn func(context.Context, openai.AudioRequest) (openai.AudioResponse, error) + }{ + { + "transcribe", + client.CreateTranscription, + }, + { + "translate", + client.CreateTranslation, + }, + } + + ctx := context.Background() + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + path := filepath.Join(t.TempDir(), "fake.mp3") + test.CreateTestFile(t, path) + + req := openai.AudioRequest{ + FilePath: path, + Model: "whisper-3", + } + _, err := tc.createFn(ctx, req) + checks.NoError(t, err, "audio API error") + }) + + t.Run(tc.name+" (with reader)", func(t *testing.T) { + req := openai.AudioRequest{ + FilePath: "fake.webm", + Reader: bytes.NewBuffer([]byte(`some webm binary data`)), + Model: "whisper-3", + } + _, err := tc.createFn(ctx, req) + checks.NoError(t, err, "audio API error") + }) + } +} + +func TestAudioWithOptionalArgs(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) + server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) + + testcases := []struct { + name string + createFn func(context.Context, openai.AudioRequest) (openai.AudioResponse, error) + }{ + { + "transcribe", + client.CreateTranscription, + }, + { + "translate", + client.CreateTranslation, + }, + } + + ctx := context.Background() + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + path := filepath.Join(t.TempDir(), "fake.mp3") + test.CreateTestFile(t, path) + + req := openai.AudioRequest{ + FilePath: path, + Model: "whisper-3", + Prompt: "用简体中文", + Temperature: 0.5, + Language: "zh", + Format: openai.AudioResponseFormatSRT, + TimestampGranularities: []openai.TranscriptionTimestampGranularity{ + openai.TranscriptionTimestampGranularitySegment, + openai.TranscriptionTimestampGranularityWord, + }, + } + _, err := tc.createFn(ctx, req) + checks.NoError(t, err, "audio API error") + }) + } +} + +// handleAudioEndpoint Handles the completion endpoint by the test server. +func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + + // audio endpoints only accept POST requests + if r.Method != "POST" { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } + + mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil { + http.Error(w, "failed to parse media type", http.StatusBadRequest) + return + } + + if !strings.HasPrefix(mediaType, "multipart") { + http.Error(w, "request is not multipart", http.StatusBadRequest) + } + + boundary, ok := params["boundary"] + if !ok { + http.Error(w, "no boundary in params", http.StatusBadRequest) + return + } + + fileData := &bytes.Buffer{} + mr := multipart.NewReader(r.Body, boundary) + part, err := mr.NextPart() + if err != nil && errors.Is(err, io.EOF) { + http.Error(w, "error accessing file", http.StatusBadRequest) + return + } + if _, err = io.Copy(fileData, part); err != nil { + http.Error(w, "failed to copy file", http.StatusInternalServerError) + return + } + + if len(fileData.Bytes()) == 0 { + w.WriteHeader(http.StatusInternalServerError) + http.Error(w, "received empty file data", http.StatusBadRequest) + return + } + + if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil { + http.Error(w, "failed to write body", http.StatusInternalServerError) + return + } +} diff --git a/audio_test.go b/audio_test.go index daf51f28c..51b3f465d 100644 --- a/audio_test.go +++ b/audio_test.go @@ -2,214 +2,240 @@ package openai //nolint:testpackage // testing private field import ( "bytes" + "context" "errors" "fmt" "io" - "mime" - "mime/multipart" "net/http" "os" "path/filepath" - "strings" + "testing" + utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" - - "context" - "testing" ) -// TestAudio Tests the transcription and translation endpoints of the API using the mocked server. -func TestAudio(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) - server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - - testcases := []struct { - name string - createFn func(context.Context, AudioRequest) (AudioResponse, error) - }{ - { - "transcribe", - client.CreateTranscription, - }, - { - "translate", - client.CreateTranslation, +func TestAudioWithFailingFormBuilder(t *testing.T) { + path := filepath.Join(t.TempDir(), "fake.mp3") + test.CreateTestFile(t, path) + + req := AudioRequest{ + FilePath: path, + Prompt: "test", + Temperature: 0.5, + Language: "en", + Format: AudioResponseFormatSRT, + TimestampGranularities: []TranscriptionTimestampGranularity{ + TranscriptionTimestampGranularitySegment, + TranscriptionTimestampGranularityWord, }, } - ctx := context.Background() + mockFailedErr := fmt.Errorf("mock form builder fail") + mockBuilder := &mockFormBuilder{} - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() + mockBuilder.mockCreateFormFile = func(string, *os.File) error { + return mockFailedErr + } + err := audioMultipartForm(req, mockBuilder) + checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails") - for _, tc := range testcases { - t.Run(tc.name, func(t *testing.T) { - path := filepath.Join(dir, "fake.mp3") - test.CreateTestFile(t, path) + mockBuilder.mockCreateFormFile = func(string, *os.File) error { + return nil + } - req := AudioRequest{ - FilePath: path, - Model: "whisper-3", - } - _, err = tc.createFn(ctx, req) - checks.NoError(t, err, "audio API error") - }) + var failForField string + mockBuilder.mockWriteField = func(fieldname, _ string) error { + if fieldname == failForField { + return mockFailedErr + } + return nil } -} -func TestAudioWithOptionalArgs(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) - server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - - testcases := []struct { - name string - createFn func(context.Context, AudioRequest) (AudioResponse, error) - }{ - { - "transcribe", - client.CreateTranscription, - }, - { - "translate", - client.CreateTranslation, - }, + failOn := []string{"model", "prompt", "temperature", "language", "response_format", "timestamp_granularities[]"} + for _, failingField := range failOn { + failForField = failingField + mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField) + + err = audioMultipartForm(req, mockBuilder) + checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails") } +} + +func TestCreateFileField(t *testing.T) { + t.Run("createFileField failing file", func(t *testing.T) { + path := filepath.Join(t.TempDir(), "fake.mp3") + test.CreateTestFile(t, path) - ctx := context.Background() + req := AudioRequest{ + FilePath: path, + } - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() + mockFailedErr := fmt.Errorf("mock form builder fail") + mockBuilder := &mockFormBuilder{ + mockCreateFormFile: func(string, *os.File) error { + return mockFailedErr + }, + } - for _, tc := range testcases { - t.Run(tc.name, func(t *testing.T) { - path := filepath.Join(dir, "fake.mp3") - test.CreateTestFile(t, path) + err := createFileField(req, mockBuilder) + checks.ErrorIs(t, err, mockFailedErr, "createFileField using a file should return error if form builder fails") + }) - req := AudioRequest{ - FilePath: path, - Model: "whisper-3", - Prompt: "用简体中文", - Temperature: 0.5, - Language: "zh", - Format: AudioResponseFormatSRT, - } - _, err = tc.createFn(ctx, req) - checks.NoError(t, err, "audio API error") - }) - } + t.Run("createFileField failing reader", func(t *testing.T) { + req := AudioRequest{ + FilePath: "test.wav", + Reader: bytes.NewBuffer([]byte(`wav test contents`)), + } + + mockFailedErr := fmt.Errorf("mock form builder fail") + mockBuilder := &mockFormBuilder{ + mockCreateFormFileReader: func(string, io.Reader, string) error { + return mockFailedErr + }, + } + + err := createFileField(req, mockBuilder) + checks.ErrorIs(t, err, mockFailedErr, "createFileField using a reader should return error if form builder fails") + }) + + t.Run("createFileField failing open", func(t *testing.T) { + req := AudioRequest{ + FilePath: "non_existing_file.wav", + } + + mockBuilder := &mockFormBuilder{} + + err := createFileField(req, mockBuilder) + checks.HasError(t, err, "createFileField using file should return error when open file fails") + }) } -// handleAudioEndpoint Handles the completion endpoint by the test server. -func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) { - var err error +// failingFormBuilder always returns an error when creating form files. +type failingFormBuilder struct{ err error } - // audio endpoints only accept POST requests - if r.Method != "POST" { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - } +func (f *failingFormBuilder) CreateFormFile(_ string, _ *os.File) error { + return f.err +} - mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) - if err != nil { - http.Error(w, "failed to parse media type", http.StatusBadRequest) - return - } +func (f *failingFormBuilder) CreateFormFileReader(_ string, _ io.Reader, _ string) error { + return f.err +} - if !strings.HasPrefix(mediaType, "multipart") { - http.Error(w, "request is not multipart", http.StatusBadRequest) - } +func (f *failingFormBuilder) WriteField(_, _ string) error { + return nil +} - boundary, ok := params["boundary"] - if !ok { - http.Error(w, "no boundary in params", http.StatusBadRequest) - return +func (f *failingFormBuilder) Close() error { + return nil +} + +func (f *failingFormBuilder) FormDataContentType() string { + return "multipart/form-data" +} + +// failingAudioRequestBuilder simulates an error during HTTP request construction. +type failingAudioRequestBuilder struct{ err error } + +func (f *failingAudioRequestBuilder) Build( + _ context.Context, + _, _ string, + _ any, + _ http.Header, +) (*http.Request, error) { + return nil, f.err +} + +// errorHTTPClient always returns an error when making HTTP calls. +type errorHTTPClient struct{ err error } + +func (e *errorHTTPClient) Do(_ *http.Request) (*http.Response, error) { + return nil, e.err +} + +func TestCallAudioAPIMultipartFormError(t *testing.T) { + client := NewClient("test-token") + errForm := errors.New("mock create form file failure") + // Override form builder to force an error during multipart form creation. + client.createFormBuilder = func(_ io.Writer) utils.FormBuilder { + return &failingFormBuilder{err: errForm} } - fileData := &bytes.Buffer{} - mr := multipart.NewReader(r.Body, boundary) - part, err := mr.NextPart() - if err != nil && errors.Is(err, io.EOF) { - http.Error(w, "error accessing file", http.StatusBadRequest) - return + // Provide a reader so createFileField uses the reader path (no file open). + req := AudioRequest{FilePath: "fake.mp3", Reader: bytes.NewBuffer([]byte("dummy")), Model: Whisper1} + _, err := client.callAudioAPI(context.Background(), req, "transcriptions") + if err == nil { + t.Fatal("expected error but got none") } - if _, err = io.Copy(fileData, part); err != nil { - http.Error(w, "failed to copy file", http.StatusInternalServerError) - return + if !errors.Is(err, errForm) { + t.Errorf("expected error %v, got %v", errForm, err) } +} - if len(fileData.Bytes()) == 0 { - w.WriteHeader(http.StatusInternalServerError) - http.Error(w, "received empty file data", http.StatusBadRequest) - return +func TestCallAudioAPINewRequestError(t *testing.T) { + client := NewClient("test-token") + // Create a real temp file so multipart form succeeds. + tmp := t.TempDir() + path := filepath.Join(tmp, "file.mp3") + if err := os.WriteFile(path, []byte("content"), 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) } - if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil { - http.Error(w, "failed to write body", http.StatusInternalServerError) - return + errBuild := errors.New("mock build failure") + client.requestBuilder = &failingAudioRequestBuilder{err: errBuild} + + req := AudioRequest{FilePath: path, Model: Whisper1} + _, err := client.callAudioAPI(context.Background(), req, "translations") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errBuild) { + t.Errorf("expected error %v, got %v", errBuild, err) } } -func TestAudioWithFailingFormBuilder(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - path := filepath.Join(dir, "fake.mp3") - test.CreateTestFile(t, path) - - req := AudioRequest{ - FilePath: path, - Prompt: "test", - Temperature: 0.5, - Language: "en", - Format: AudioResponseFormatSRT, +func TestCallAudioAPISendRequestErrorJSON(t *testing.T) { + client := NewClient("test-token") + // Create a real temp file so multipart form succeeds. + tmp := t.TempDir() + path := filepath.Join(tmp, "file.mp3") + if err := os.WriteFile(path, []byte("content"), 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) } - mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder := &mockFormBuilder{} + errHTTP := errors.New("mock HTTPClient failure") + // Override HTTP client to simulate a network error. + client.config.HTTPClient = &errorHTTPClient{err: errHTTP} - mockBuilder.mockCreateFormFile = func(string, *os.File) error { - return mockFailedErr + req := AudioRequest{FilePath: path, Model: Whisper1} + _, err := client.callAudioAPI(context.Background(), req, "transcriptions") + if err == nil { + t.Fatal("expected error but got none") } - err := audioMultipartForm(req, mockBuilder) - checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails") - - mockBuilder.mockCreateFormFile = func(string, *os.File) error { - return nil + if !errors.Is(err, errHTTP) { + t.Errorf("expected error %v, got %v", errHTTP, err) } +} - var failForField string - mockBuilder.mockWriteField = func(fieldname, value string) error { - if fieldname == failForField { - return mockFailedErr - } - return nil +func TestCallAudioAPISendRequestErrorText(t *testing.T) { + client := NewClient("test-token") + tmp := t.TempDir() + path := filepath.Join(tmp, "file.mp3") + if err := os.WriteFile(path, []byte("content"), 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) } - failOn := []string{"model", "prompt", "temperature", "language", "response_format"} - for _, failingField := range failOn { - failForField = failingField - mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField) + errHTTP := errors.New("mock HTTPClient failure") + client.config.HTTPClient = &errorHTTPClient{err: errHTTP} - err = audioMultipartForm(req, mockBuilder) - checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails") + // Use a non-JSON response format to exercise the text path. + req := AudioRequest{FilePath: path, Model: Whisper1, Format: AudioResponseFormatText} + _, err := client.callAudioAPI(context.Background(), req, "translations") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errHTTP) { + t.Errorf("expected error %v, got %v", errHTTP, err) } } diff --git a/batch.go b/batch.go new file mode 100644 index 000000000..3c1a9d0d7 --- /dev/null +++ b/batch.go @@ -0,0 +1,271 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" +) + +const batchesSuffix = "/batches" + +type BatchEndpoint string + +const ( + BatchEndpointChatCompletions BatchEndpoint = "/v1/chat/completions" + BatchEndpointCompletions BatchEndpoint = "/v1/completions" + BatchEndpointEmbeddings BatchEndpoint = "/v1/embeddings" +) + +type BatchLineItem interface { + MarshalBatchLineItem() []byte +} + +type BatchChatCompletionRequest struct { + CustomID string `json:"custom_id"` + Body ChatCompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchChatCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchCompletionRequest struct { + CustomID string `json:"custom_id"` + Body CompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchEmbeddingRequest struct { + CustomID string `json:"custom_id"` + Body EmbeddingRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchEmbeddingRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type Batch struct { + ID string `json:"id"` + Object string `json:"object"` + Endpoint BatchEndpoint `json:"endpoint"` + Errors *struct { + Object string `json:"object,omitempty"` + Data []struct { + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Param *string `json:"param,omitempty"` + Line *int `json:"line,omitempty"` + } `json:"data"` + } `json:"errors"` + InputFileID string `json:"input_file_id"` + CompletionWindow string `json:"completion_window"` + Status string `json:"status"` + OutputFileID *string `json:"output_file_id"` + ErrorFileID *string `json:"error_file_id"` + CreatedAt int `json:"created_at"` + InProgressAt *int `json:"in_progress_at"` + ExpiresAt *int `json:"expires_at"` + FinalizingAt *int `json:"finalizing_at"` + CompletedAt *int `json:"completed_at"` + FailedAt *int `json:"failed_at"` + ExpiredAt *int `json:"expired_at"` + CancellingAt *int `json:"cancelling_at"` + CancelledAt *int `json:"cancelled_at"` + RequestCounts BatchRequestCounts `json:"request_counts"` + Metadata map[string]any `json:"metadata"` +} + +type BatchRequestCounts struct { + Total int `json:"total"` + Completed int `json:"completed"` + Failed int `json:"failed"` +} + +type CreateBatchRequest struct { + InputFileID string `json:"input_file_id"` + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` +} + +type BatchResponse struct { + httpHeader + Batch +} + +// CreateBatch — API call to Create batch. +func (c *Client) CreateBatch( + ctx context.Context, + request CreateBatchRequest, +) (response BatchResponse, err error) { + if request.CompletionWindow == "" { + request.CompletionWindow = "24h" + } + + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(batchesSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +type UploadBatchFileRequest struct { + FileName string + Lines []BatchLineItem +} + +func (r *UploadBatchFileRequest) MarshalJSONL() []byte { + buff := bytes.Buffer{} + for i, line := range r.Lines { + if i != 0 { + buff.Write([]byte("\n")) + } + buff.Write(line.MarshalBatchLineItem()) + } + return buff.Bytes() +} + +func (r *UploadBatchFileRequest) AddChatCompletion(customerID string, body ChatCompletionRequest) { + r.Lines = append(r.Lines, BatchChatCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointChatCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddCompletion(customerID string, body CompletionRequest) { + r.Lines = append(r.Lines, BatchCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddEmbedding(customerID string, body EmbeddingRequest) { + r.Lines = append(r.Lines, BatchEmbeddingRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointEmbeddings, + }) +} + +// UploadBatchFile — upload batch file. +func (c *Client) UploadBatchFile(ctx context.Context, request UploadBatchFileRequest) (File, error) { + if request.FileName == "" { + request.FileName = "@batchinput.jsonl" + } + return c.CreateFileBytes(ctx, FileBytesRequest{ + Name: request.FileName, + Bytes: request.MarshalJSONL(), + Purpose: PurposeBatch, + }) +} + +type CreateBatchWithUploadFileRequest struct { + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` + UploadBatchFileRequest +} + +// CreateBatchWithUploadFile — API call to Create batch with upload file. +func (c *Client) CreateBatchWithUploadFile( + ctx context.Context, + request CreateBatchWithUploadFileRequest, +) (response BatchResponse, err error) { + var file File + file, err = c.UploadBatchFile(ctx, UploadBatchFileRequest{ + FileName: request.FileName, + Lines: request.Lines, + }) + if err != nil { + return + } + return c.CreateBatch(ctx, CreateBatchRequest{ + InputFileID: file.ID, + Endpoint: request.Endpoint, + CompletionWindow: request.CompletionWindow, + Metadata: request.Metadata, + }) +} + +// RetrieveBatch — API call to Retrieve batch. +func (c *Client) RetrieveBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +// CancelBatch — API call to Cancel batch. +func (c *Client) CancelBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s/cancel", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +type ListBatchResponse struct { + httpHeader + Object string `json:"object"` + Data []Batch `json:"data"` + FirstID string `json:"first_id"` + LastID string `json:"last_id"` + HasMore bool `json:"has_more"` +} + +// ListBatch API call to List batch. +func (c *Client) ListBatch(ctx context.Context, after *string, limit *int) (response ListBatchResponse, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if after != nil { + urlValues.Add("after", *after) + } + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", batchesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/batch_test.go b/batch_test.go new file mode 100644 index 000000000..f4714f4eb --- /dev/null +++ b/batch_test.go @@ -0,0 +1,368 @@ +package openai_test + +import ( + "context" + "fmt" + "net/http" + "reflect" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestUploadBatchFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/files", handleCreateFile) + req := openai.UploadBatchFileRequest{} + req.AddChatCompletion("req-1", openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + _, err := client.UploadBatchFile(context.Background(), req) + checks.NoError(t, err, "UploadBatchFile error") +} + +func TestCreateBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + _, err := client.CreateBatch(context.Background(), openai.CreateBatchRequest{ + InputFileID: "file-abc", + Endpoint: openai.BatchEndpointChatCompletions, + CompletionWindow: "24h", + }) + checks.NoError(t, err, "CreateBatch error") +} + +func TestCreateBatchWithUploadFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", handleCreateFile) + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + req := openai.CreateBatchWithUploadFileRequest{ + Endpoint: openai.BatchEndpointChatCompletions, + } + req.AddChatCompletion("req-1", openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + _, err := client.CreateBatchWithUploadFile(context.Background(), req) + checks.NoError(t, err, "CreateBatchWithUploadFile error") +} + +func TestRetrieveBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches/file-id-1", handleRetrieveBatchEndpoint) + _, err := client.RetrieveBatch(context.Background(), "file-id-1") + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestCancelBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches/file-id-1/cancel", handleCancelBatchEndpoint) + _, err := client.CancelBatch(context.Background(), "file-id-1") + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestListBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + after := "batch_abc123" + limit := 10 + _, err := client.ListBatch(context.Background(), &after, &limit) + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestUploadBatchFileRequest_AddChatCompletion(t *testing.T) { + type args struct { + customerID string + body openai.ChatCompletionRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + }, + { + customerID: "req-2", + body: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello!\"}],\"max_tokens\":5},\"method\":\"POST\",\"url\":\"/v1/chat/completions\"}\n{\"custom_id\":\"req-2\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello!\"}],\"max_tokens\":5},\"method\":\"POST\",\"url\":\"/v1/chat/completions\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddChatCompletion(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUploadBatchFileRequest_AddCompletion(t *testing.T) { + type args struct { + customerID string + body openai.CompletionRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.CompletionRequest{ + Model: openai.GPT3Dot5Turbo, + User: "Hello", + }, + }, + { + customerID: "req-2", + body: openai.CompletionRequest{ + Model: openai.GPT3Dot5Turbo, + User: "Hello", + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"user\":\"Hello\"},\"method\":\"POST\",\"url\":\"/v1/completions\"}\n{\"custom_id\":\"req-2\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"user\":\"Hello\"},\"method\":\"POST\",\"url\":\"/v1/completions\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddCompletion(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUploadBatchFileRequest_AddEmbedding(t *testing.T) { + type args struct { + customerID string + body openai.EmbeddingRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.EmbeddingRequest{ + Model: openai.GPT3Dot5Turbo, + Input: []string{"Hello", "World"}, + }, + }, + { + customerID: "req-2", + body: openai.EmbeddingRequest{ + Model: openai.AdaEmbeddingV2, + Input: []string{"Hello", "World"}, + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"gpt-3.5-turbo\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}\n{\"custom_id\":\"req-2\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"text-embedding-ada-002\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddEmbedding(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func handleBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } else if r.Method == http.MethodGet { + _, _ = fmt.Fprintln(w, `{ + "object": "list", + "data": [ + { + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly job" + } + } + ], + "first_id": "batch_abc123", + "last_id": "batch_abc456", + "has_more": true + }`) + } +} + +func handleRetrieveBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } +} + +func handleCancelBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "cancelling", + "output_file_id": null, + "error_file_id": null, + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": null, + "completed_at": null, + "failed_at": null, + "expired_at": null, + "cancelling_at": 1711475133, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 23, + "failed": 1 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } +} diff --git a/chat.go b/chat.go index c09861c8c..0aa018715 100644 --- a/chat.go +++ b/chat.go @@ -2,8 +2,11 @@ package openai import ( "context" + "encoding/json" "errors" "net/http" + + "github.com/sashabaranov/go-openai/jsonschema" ) // Chat message role defined by the OpenAI API. @@ -11,54 +14,452 @@ const ( ChatMessageRoleSystem = "system" ChatMessageRoleUser = "user" ChatMessageRoleAssistant = "assistant" + ChatMessageRoleFunction = "function" + ChatMessageRoleTool = "tool" + ChatMessageRoleDeveloper = "developer" ) +const chatCompletionsSuffix = "/chat/completions" + var ( ErrChatCompletionInvalidModel = errors.New("this model is not supported with this method, please use CreateCompletion client method instead") //nolint:lll ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll + ErrContentFieldsMisused = errors.New("can't use both Content and MultiContent properties simultaneously") +) + +type Hate struct { + Filtered bool `json:"filtered"` + Severity string `json:"severity,omitempty"` +} +type SelfHarm struct { + Filtered bool `json:"filtered"` + Severity string `json:"severity,omitempty"` +} +type Sexual struct { + Filtered bool `json:"filtered"` + Severity string `json:"severity,omitempty"` +} +type Violence struct { + Filtered bool `json:"filtered"` + Severity string `json:"severity,omitempty"` +} + +type JailBreak struct { + Filtered bool `json:"filtered"` + Detected bool `json:"detected"` +} + +type Profanity struct { + Filtered bool `json:"filtered"` + Detected bool `json:"detected"` +} + +type ContentFilterResults struct { + Hate Hate `json:"hate,omitempty"` + SelfHarm SelfHarm `json:"self_harm,omitempty"` + Sexual Sexual `json:"sexual,omitempty"` + Violence Violence `json:"violence,omitempty"` + JailBreak JailBreak `json:"jailbreak,omitempty"` + Profanity Profanity `json:"profanity,omitempty"` +} + +type PromptAnnotation struct { + PromptIndex int `json:"prompt_index,omitempty"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` +} + +type ImageURLDetail string + +const ( + ImageURLDetailHigh ImageURLDetail = "high" + ImageURLDetailLow ImageURLDetail = "low" + ImageURLDetailAuto ImageURLDetail = "auto" ) +type ChatMessageImageURL struct { + URL string `json:"url,omitempty"` + Detail ImageURLDetail `json:"detail,omitempty"` +} + +type ChatMessagePartType string + +const ( + ChatMessagePartTypeText ChatMessagePartType = "text" + ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" +) + +type ChatMessagePart struct { + Type ChatMessagePartType `json:"type,omitempty"` + Text string `json:"text,omitempty"` + ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` +} + type ChatCompletionMessage struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content,omitempty"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart // This property isn't in the official documentation, but it's in // the documentation for the official library for python: // - https://github.com/openai/openai-python/blob/main/chatml.md // - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb Name string `json:"name,omitempty"` + + // This property is used for the "reasoning" feature supported by deepseek-reasoner + // which is not in the official documentation. + // the doc from deepseek: + // - https://api-docs.deepseek.com/api/create-chat-completion#responses + ReasoningContent string `json:"reasoning_content,omitempty"` + + FunctionCall *FunctionCall `json:"function_call,omitempty"` + + // For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls. + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + + // For Role=tool prompts this should be set to the ID given in the assistant's prior request to call a tool. + ToolCallID string `json:"tool_call_id,omitempty"` +} + +func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { + if m.Content != "" && m.MultiContent != nil { + return nil, ErrContentFieldsMisused + } + if len(m.MultiContent) > 0 { + msg := struct { + Role string `json:"role"` + Content string `json:"-"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"content,omitempty"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }(m) + return json.Marshal(msg) + } + + msg := struct { + Role string `json:"role"` + Content string `json:"content,omitempty"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"-"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }(m) + return json.Marshal(msg) +} + +func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { + msg := struct { + Role string `json:"role"` + Content string `json:"content"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }{} + + if err := json.Unmarshal(bs, &msg); err == nil { + *m = ChatCompletionMessage(msg) + return nil + } + multiMsg := struct { + Role string `json:"role"` + Content string + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"content"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }{} + if err := json.Unmarshal(bs, &multiMsg); err != nil { + return err + } + *m = ChatCompletionMessage(multiMsg) + return nil +} + +type ToolCall struct { + // Index is not nil only in chat completion chunk object + Index *int `json:"index,omitempty"` + ID string `json:"id,omitempty"` + Type ToolType `json:"type"` + Function FunctionCall `json:"function"` +} + +type FunctionCall struct { + Name string `json:"name,omitempty"` + // call function with arguments in JSON format + Arguments string `json:"arguments,omitempty"` +} + +type ChatCompletionResponseFormatType string + +const ( + ChatCompletionResponseFormatTypeJSONObject ChatCompletionResponseFormatType = "json_object" + ChatCompletionResponseFormatTypeJSONSchema ChatCompletionResponseFormatType = "json_schema" + ChatCompletionResponseFormatTypeText ChatCompletionResponseFormatType = "text" +) + +type ChatCompletionResponseFormat struct { + Type ChatCompletionResponseFormatType `json:"type,omitempty"` + JSONSchema *ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"` +} + +type ChatCompletionResponseFormatJSONSchema struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema json.Marshaler `json:"schema"` + Strict bool `json:"strict"` +} + +func (r *ChatCompletionResponseFormatJSONSchema) UnmarshalJSON(data []byte) error { + type rawJSONSchema struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema json.RawMessage `json:"schema"` + Strict bool `json:"strict"` + } + var raw rawJSONSchema + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + r.Name = raw.Name + r.Description = raw.Description + r.Strict = raw.Strict + if len(raw.Schema) > 0 && string(raw.Schema) != "null" { + var d jsonschema.Definition + err := json.Unmarshal(raw.Schema, &d) + if err != nil { + return err + } + r.Schema = &d + } + return nil +} + +// ChatCompletionRequestExtensions contains third-party OpenAI API extensions +// (e.g., vendor-specific implementations like vLLM). +type ChatCompletionRequestExtensions struct { + // GuidedChoice is a vLLM-specific extension that restricts the model's output + // to one of the predefined string choices provided in this field. This feature + // is used to constrain the model's responses to a controlled set of options, + // ensuring predictable and consistent outputs in scenarios where specific + // choices are required. + GuidedChoice []string `json:"guided_choice,omitempty"` } // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []ChatCompletionMessage `json:"messages"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` - LogitBias map[string]int `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` + Model string `json:"model"` + Messages []ChatCompletionMessage `json:"messages"` + // MaxTokens The maximum number of tokens that can be generated in the chat completion. + // This value can be used to control costs for text generated via API. + // Deprecated: use MaxCompletionTokens. Not compatible with o1-series models. + // refs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens + MaxTokens int `json:"max_tokens,omitempty"` + // MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion, + // including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. + // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` + // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias + LogitBias map[string]int `json:"logit_bias,omitempty"` + // LogProbs indicates whether to return log probabilities of the output tokens or not. + // If true, returns the log probabilities of each output token returned in the content of message. + // This option is currently not available on the gpt-4-vision-preview model. + LogProbs bool `json:"logprobs,omitempty"` + // TopLogProbs is an integer between 0 and 5 specifying the number of most likely tokens to return at each + // token position, each with an associated log probability. + // logprobs must be set to true if this parameter is used. + TopLogProbs int `json:"top_logprobs,omitempty"` + User string `json:"user,omitempty"` + // Deprecated: use Tools instead. + Functions []FunctionDefinition `json:"functions,omitempty"` + // Deprecated: use ToolChoice instead. + FunctionCall any `json:"function_call,omitempty"` + Tools []Tool `json:"tools,omitempty"` + // This can be either a string or an ToolChoice object. + ToolChoice any `json:"tool_choice,omitempty"` + // Options for streaming response. Only set this when you set stream: true. + StreamOptions *StreamOptions `json:"stream_options,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. + ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` + // Store can be set to true to store the output of this completion request for use in distillations and evals. + // https://platform.openai.com/docs/api-reference/chat/create#chat-create-store + Store bool `json:"store,omitempty"` + // Controls effort on reasoning for reasoning models. It can be set to "low", "medium", or "high". + ReasoningEffort string `json:"reasoning_effort,omitempty"` + // Metadata to store with the completion. + Metadata map[string]string `json:"metadata,omitempty"` + // Configuration for a predicted output. + Prediction *Prediction `json:"prediction,omitempty"` + // ChatTemplateKwargs provides a way to add non-standard parameters to the request body. + // Additional kwargs to pass to the template renderer. Will be accessible by the chat template. + // Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false} + // https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes + ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` + // Specifies the latency tier to use for processing the request. + ServiceTier ServiceTier `json:"service_tier,omitempty"` + // Verbosity determines how many output tokens are generated. Lowering the number of + // tokens reduces overall latency. It can be set to "low", "medium", or "high". + // Note: This field is only confirmed to work with gpt-5, gpt-5-mini and gpt-5-nano. + // Also, it is not in the API reference of chat completion at the time of writing, + // though it is supported by the API. + Verbosity string `json:"verbosity,omitempty"` + // A stable identifier used to help detect users of your application that may be violating OpenAI's usage policies. + // The IDs should be a string that uniquely identifies each user. + // We recommend hashing their username or email address, in order to avoid sending us any identifying information. + // https://platform.openai.com/docs/api-reference/chat/create#chat_create-safety_identifier + SafetyIdentifier string `json:"safety_identifier,omitempty"` + // Embedded struct for non-OpenAI extensions + ChatCompletionRequestExtensions +} + +type StreamOptions struct { + // If set, an additional chunk will be streamed before the data: [DONE] message. + // The usage field on this chunk shows the token usage statistics for the entire request, + // and the choices field will always be an empty array. + // All other chunks will also include a usage field, but with a null value. + IncludeUsage bool `json:"include_usage,omitempty"` +} + +type ToolType string + +const ( + ToolTypeFunction ToolType = "function" +) + +type Tool struct { + Type ToolType `json:"type"` + Function *FunctionDefinition `json:"function,omitempty"` +} + +type ToolChoice struct { + Type ToolType `json:"type"` + Function ToolFunction `json:"function,omitempty"` +} + +type ToolFunction struct { + Name string `json:"name"` +} + +type FunctionDefinition struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Strict bool `json:"strict,omitempty"` + // Parameters is an object describing the function. + // You can pass json.RawMessage to describe the schema, + // or you can pass in a struct which serializes to the proper JSON schema. + // The jsonschema package is provided for convenience, but you should + // consider another specialized library if you require more complex schemas. + Parameters any `json:"parameters"` +} + +// Deprecated: use FunctionDefinition instead. +type FunctionDefine = FunctionDefinition + +type TopLogProbs struct { + Token string `json:"token"` + LogProb float64 `json:"logprob"` + Bytes []byte `json:"bytes,omitempty"` +} + +// LogProb represents the probability information for a token. +type LogProb struct { + Token string `json:"token"` + LogProb float64 `json:"logprob"` + Bytes []byte `json:"bytes,omitempty"` // Omitting the field if it is null + // TopLogProbs is a list of the most likely tokens and their log probability, at this token position. + // In rare cases, there may be fewer than the number of requested top_logprobs returned. + TopLogProbs []TopLogProbs `json:"top_logprobs"` +} + +// LogProbs is the top-level structure containing the log probability information. +type LogProbs struct { + // Content is a list of message content tokens with log probability information. + Content []LogProb `json:"content"` +} + +type Prediction struct { + Content string `json:"content"` + Type string `json:"type"` +} + +type FinishReason string + +const ( + FinishReasonStop FinishReason = "stop" + FinishReasonLength FinishReason = "length" + FinishReasonFunctionCall FinishReason = "function_call" + FinishReasonToolCalls FinishReason = "tool_calls" + FinishReasonContentFilter FinishReason = "content_filter" + FinishReasonNull FinishReason = "null" +) + +type ServiceTier string + +const ( + ServiceTierAuto ServiceTier = "auto" + ServiceTierDefault ServiceTier = "default" + ServiceTierFlex ServiceTier = "flex" + ServiceTierPriority ServiceTier = "priority" +) + +func (r FinishReason) MarshalJSON() ([]byte, error) { + if r == FinishReasonNull || r == "" { + return []byte("null"), nil + } + return []byte(`"` + string(r) + `"`), nil // best effort to not break future API changes } type ChatCompletionChoice struct { - Index int `json:"index"` - Message ChatCompletionMessage `json:"message"` - FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Message ChatCompletionMessage `json:"message"` + // FinishReason + // stop: API returned complete message, + // or a message terminated by one of the stop sequences provided via the stop parameter + // length: Incomplete model output due to max_tokens parameter or token limit + // function_call: The model decided to call a function + // content_filter: Omitted content due to a flag from our content filters + // null: API response still in progress or incomplete + FinishReason FinishReason `json:"finish_reason"` + LogProbs *LogProbs `json:"logprobs,omitempty"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } // ChatCompletionResponse represents a response structure for chat completion API. type ChatCompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionChoice `json:"choices"` - Usage Usage `json:"usage"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionChoice `json:"choices"` + Usage Usage `json:"usage"` + SystemFingerprint string `json:"system_fingerprint"` + PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` + ServiceTier ServiceTier `json:"service_tier,omitempty"` + + httpHeader } // CreateChatCompletion — API call to Create a completion for the chat message. @@ -71,13 +472,23 @@ func (c *Client) CreateChatCompletion( return } - urlSuffix := "/chat/completions" + urlSuffix := chatCompletionsSuffix if !checkEndpointSupportsModel(urlSuffix, request.Model) { err = ErrChatCompletionInvalidModel return } - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + reasoningValidator := NewReasoningValidator() + if err = reasoningValidator.Validate(request); err != nil { + return + } + + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } diff --git a/chat_stream.go b/chat_stream.go index 9ed0bc70a..80d16cc63 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -1,28 +1,68 @@ package openai import ( - "bufio" "context" "net/http" ) type ChatCompletionStreamChoiceDelta struct { - Content string `json:"content,omitempty"` - Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + Role string `json:"role,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + Refusal string `json:"refusal,omitempty"` + + // This property is used for the "reasoning" feature supported by deepseek-reasoner + // which is not in the official documentation. + // the doc from deepseek: + // - https://api-docs.deepseek.com/api/create-chat-completion#responses + ReasoningContent string `json:"reasoning_content,omitempty"` +} + +type ChatCompletionStreamChoiceLogprobs struct { + Content []ChatCompletionTokenLogprob `json:"content,omitempty"` + Refusal []ChatCompletionTokenLogprob `json:"refusal,omitempty"` +} + +type ChatCompletionTokenLogprob struct { + Token string `json:"token"` + Bytes []int64 `json:"bytes,omitempty"` + Logprob float64 `json:"logprob,omitempty"` + TopLogprobs []ChatCompletionTokenLogprobTopLogprob `json:"top_logprobs"` +} + +type ChatCompletionTokenLogprobTopLogprob struct { + Token string `json:"token"` + Bytes []int64 `json:"bytes"` + Logprob float64 `json:"logprob"` } type ChatCompletionStreamChoice struct { - Index int `json:"index"` - Delta ChatCompletionStreamChoiceDelta `json:"delta"` - FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Delta ChatCompletionStreamChoiceDelta `json:"delta"` + Logprobs *ChatCompletionStreamChoiceLogprobs `json:"logprobs,omitempty"` + FinishReason FinishReason `json:"finish_reason"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` +} + +type PromptFilterResult struct { + Index int `json:"index"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } type ChatCompletionStreamResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionStreamChoice `json:"choices"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionStreamChoice `json:"choices"` + SystemFingerprint string `json:"system_fingerprint"` + PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` + PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` + // An optional field that will only be present when you set stream_options: {"include_usage": true} in your request. + // When present, it contains a null value except for the last chunk which contains the token usage statistics + // for the entire request. + Usage *Usage `json:"usage,omitempty"` } // ChatCompletionStream @@ -39,34 +79,34 @@ func (c *Client) CreateChatCompletionStream( ctx context.Context, request ChatCompletionRequest, ) (stream *ChatCompletionStream, err error) { - urlSuffix := "/chat/completions" + urlSuffix := chatCompletionsSuffix if !checkEndpointSupportsModel(urlSuffix, request.Model) { err = ErrChatCompletionInvalidModel return } request.Stream = true - req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request) - if err != nil { + reasoningValidator := NewReasoningValidator() + if err = reasoningValidator.Validate(request); err != nil { return } - resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { - return - } - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest { - return nil, c.handleErrorResp(resp) + return nil, err } + resp, err := sendRequestStream[ChatCompletionStreamResponse](c, req) + if err != nil { + return + } stream = &ChatCompletionStream{ - streamReader: &streamReader[ChatCompletionStreamResponse]{ - emptyMessagesLimit: c.config.EmptyMessagesLimit, - reader: bufio.NewReader(resp.Body), - response: resp, - errAccumulator: newErrorAccumulator(), - unmarshaler: &jsonUnmarshaler{}, - }, + streamReader: resp, } return } diff --git a/chat_stream_test.go b/chat_stream_test.go index afcb86d5e..eabb0f3a2 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1,55 +1,57 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "errors" + "fmt" "io" "net/http" - "net/http/httptest" + "strconv" "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestChatCompletionsStreamWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := ChatCompletionRequest{ + req := openai.ChatCompletionRequest{ MaxTokens: 5, Model: "ada", - Messages: []ChatCompletionMessage{ + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, } _, err := client.CreateChatCompletionStream(ctx, req) - if !errors.Is(err, ErrChatCompletionInvalidModel) { + if !errors.Is(err, openai.ErrChatCompletionInvalidModel) { t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err) } } func TestCreateChatCompletionStream(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses dataBytes := []byte{} dataBytes = append(dataBytes, []byte("event: message\n")...) //nolint:lll - data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` + data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("event: message\n")...) //nolint:lll - data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` + data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("event: done\n")...) @@ -57,45 +59,32 @@ func TestCreateChatCompletionStream(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() + }) - request := ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, Stream: true, - } - - stream, err := client.CreateChatCompletionStream(ctx, request) + }) checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() - expectedResponses := []ChatCompletionStreamResponse{ + expectedResponses := []openai.ChatCompletionStreamResponse{ { - ID: "1", - Object: "completion", - Created: 1598069254, - Model: GPT3Dot5Turbo, - Choices: []ChatCompletionStreamChoice{ + ID: "1", + Object: "completion", + Created: 1598069254, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ { - Delta: ChatCompletionStreamChoiceDelta{ + Delta: openai.ChatCompletionStreamChoiceDelta{ Content: "response1", }, FinishReason: "max_tokens", @@ -103,13 +92,14 @@ func TestCreateChatCompletionStream(t *testing.T) { }, }, { - ID: "2", - Object: "completion", - Created: 1598069255, - Model: GPT3Dot5Turbo, - Choices: []ChatCompletionStreamChoice{ + ID: "2", + Object: "completion", + Created: 1598069255, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ { - Delta: ChatCompletionStreamChoiceDelta{ + Delta: openai.ChatCompletionStreamChoiceDelta{ Content: "response2", }, FinishReason: "max_tokens", @@ -143,7 +133,9 @@ func TestCreateChatCompletionStream(t *testing.T) { } func TestCreateChatCompletionStreamError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -164,40 +156,146 @@ func TestCreateChatCompletionStreamError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + _, streamErr := stream.Recv() + checks.HasError(t, streamErr, "stream.Recv() did not return error") + + var apiErr *openai.APIError + if !errors.As(streamErr, &apiErr) { + t.Errorf("stream.Recv() did not return APIError") } + t.Logf("%+v\n", apiErr) +} - client := NewClientWithConfig(config) - ctx := context.Background() +func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set(xCustomHeader, xCustomHeaderValue) + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) - request := ChatCompletionRequest{ + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + value := stream.Header().Get(xCustomHeader) + if value != xCustomHeaderValue { + t.Errorf("expected %s to be %s", xCustomHeaderValue, value) } +} - stream, err := client.CreateChatCompletionStream(ctx, request) +func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + headers := stream.GetRateLimitHeaders() + bs1, _ := json.Marshal(headers) + bs2, _ := json.Marshal(rateLimitHeaders) + if string(bs1) != string(bs2) { + t.Errorf("expected rate limit header %s to be %s", bs2, bs1) + } +} + +func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() _, streamErr := stream.Recv() checks.HasError(t, streamErr, "stream.Recv() did not return error") - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError") } @@ -205,8 +303,9 @@ func TestCreateChatCompletionStreamError(t *testing.T) { } func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(429) @@ -220,43 +319,456 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") }) - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() + _, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + var apiErr *openai.APIError + if !errors.As(err, &apiErr) { + t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError") + } + t.Logf("%+v\n", apiErr) +} + +func TestCreateChatCompletionStreamWithRefusal(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + dataBytes := []byte{} + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"refusal":"Hello"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"refusal":" World"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 2000, + Model: openai.GPT4oMini20240718, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + }, + }, + }, + { + ID: "2", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Refusal: "Hello", + }, + }, + }, + }, + { + ID: "3", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Refusal: " World", + }, + }, + }, + }, + { + ID: "4", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + FinishReason: "stop", + }, + }, + }, } - client := NewClientWithConfig(config) - ctx := context.Background() + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) - request := ChatCompletionRequest{ - MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } +} + +func TestCreateChatCompletionStreamWithLogprobs(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":{"content":[],"refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":{"content":[{"token":"Hello","logprob":-0.000020458236,"bytes":[72,101,108,108,111],"top_logprobs":[]}],"refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":" World"},"logprobs":{"content":[{"token":" World","logprob":-0.00055303273,"bytes":[32,87,111,114,108,100],"top_logprobs":[]}],"refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 2000, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{ + Content: []openai.ChatCompletionTokenLogprob{}, + }, + }, + }, + }, + { + ID: "2", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{ + Content: []openai.ChatCompletionTokenLogprob{ + { + Token: "Hello", + Logprob: -0.000020458236, + Bytes: []int64{72, 101, 108, 108, 111}, + TopLogprobs: []openai.ChatCompletionTokenLogprobTopLogprob{}, + }, + }, + }, + }, + }, + }, + { + ID: "3", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: " World", + }, + Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{ + Content: []openai.ChatCompletionTokenLogprob{ + { + Token: " World", + Logprob: -0.00055303273, + Bytes: []int64{32, 87, 111, 114, 108, 100}, + TopLogprobs: []openai.ChatCompletionTokenLogprobTopLogprob{}, + }, + }, + }, + }, + }, + }, + { + ID: "4", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, } - var apiErr *APIError - _, err := client.CreateChatCompletionStream(ctx, request) + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } +} + +func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { + wantCode := "429" + wantMessage := "Requests to the Creates a completion for the chat message Operation under Azure OpenAI API " + + "version 2023-03-15-preview have exceeded token rate limit of your current OpenAI S0 pricing tier. " + + "Please retry after 20 seconds. " + + "Please go here: https://aka.ms/oai/quotaincrease if you would like to further increase the default rate limit." + + client, server, teardown := setupAzureTestServer() + defer teardown() + server.RegisterHandler("/openai/deployments/gpt-35-turbo/chat/completions", + func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + // Send test responses + dataBytes := []byte(`{"error": { "code": "` + wantCode + `", "message": "` + wantMessage + `"}}`) + _, err := w.Write(dataBytes) + + checks.NoError(t, err, "Write error") + }) + + apiErr := &openai.APIError{} + _, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) if !errors.As(err, &apiErr) { - t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError") + t.Errorf("Did not return APIError: %+v\n", apiErr) + return + } + if apiErr.HTTPStatusCode != http.StatusTooManyRequests { + t.Errorf("Did not return HTTPStatusCode got = %d, want = %d\n", apiErr.HTTPStatusCode, http.StatusTooManyRequests) + return + } + code, ok := apiErr.Code.(string) + if !ok || code != wantCode { + t.Errorf("Did not return Code. got = %v, want = %s\n", apiErr.Code, wantCode) + return + } + if apiErr.Message != wantMessage { + t.Errorf("Did not return Message. got = %s, want = %s\n", apiErr.Message, wantMessage) + return + } +} + +func TestCreateChatCompletionStreamStreamOptions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + var dataBytes []byte + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}],"usage":null}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + //nolint:lll + data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}],"usage":null}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + //nolint:lll + data = `{"id":"3","object":"completion","created":1598069256,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + StreamOptions: &openai.StreamOptions{ + IncludeUsage: true, + }, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "completion", + Created: 1598069254, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "response1", + }, + FinishReason: "max_tokens", + }, + }, + }, + { + ID: "2", + Object: "completion", + Created: 1598069255, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "response2", + }, + FinishReason: "max_tokens", + }, + }, + }, + { + ID: "3", + Object: "completion", + Created: 1598069256, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{}, + Usage: &openai.Usage{ + PromptTokens: 1, + CompletionTokens: 1, + TotalTokens: 2, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } + + _, streamErr = stream.Recv() + + checks.ErrorIs(t, streamErr, io.EOF, "stream.Recv() did not return EOF when the stream is finished") + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr) } - t.Logf("%+v\n", apiErr) } // Helper funcs. -func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { +func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { return false } @@ -268,10 +780,236 @@ func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { return false } } + if r1.Usage != nil || r2.Usage != nil { + if r1.Usage == nil || r2.Usage == nil { + return false + } + if r1.Usage.PromptTokens != r2.Usage.PromptTokens || r1.Usage.CompletionTokens != r2.Usage.CompletionTokens || + r1.Usage.TotalTokens != r2.Usage.TotalTokens { + return false + } + } return true } -func compareChatStreamResponseChoices(c1, c2 ChatCompletionStreamChoice) bool { +func TestCreateChatCompletionStreamWithReasoningModel(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + dataBytes := []byte{} + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" from"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" O3Mini"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"5","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxCompletionTokens: 2000, + Model: openai.O3Mini20250131, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Role: "assistant", + }, + }, + }, + }, + { + ID: "2", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + }, + }, + }, + { + ID: "3", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: " from", + }, + }, + }, + }, + { + ID: "4", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: " O3Mini", + }, + }, + }, + }, + { + ID: "5", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } +} + +func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) { + client, _, _ := setupOpenAITestServer() + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 100, // This will trigger the validator to fail + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + + if stream != nil { + t.Error("Expected nil stream when validation fails") + stream.Close() + } + + if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) { + t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated, got: %v", err) + } +} + +func TestCreateChatCompletionStreamO3ReasoningValidatorFails(t *testing.T) { + client, _, _ := setupOpenAITestServer() + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 100, // This will trigger the validator to fail + Model: openai.O3, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + + if stream != nil { + t.Error("Expected nil stream when validation fails") + stream.Close() + } + + if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) { + t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O3, got: %v", err) + } +} + +func TestCreateChatCompletionStreamO4MiniReasoningValidatorFails(t *testing.T) { + client, _, _ := setupOpenAITestServer() + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 100, // This will trigger the validator to fail + Model: openai.O4Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + + if stream != nil { + t.Error("Expected nil stream when validation fails") + stream.Close() + } + + if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) { + t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O4Mini, got: %v", err) + } +} + +func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool { if c1.Index != c2.Index { return false } diff --git a/chat_test.go b/chat_test.go index ce302a69f..236cff736 100644 --- a/chat_test.go +++ b/chat_test.go @@ -1,12 +1,9 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -14,69 +11,869 @@ import ( "strings" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/sashabaranov/go-openai/jsonschema" ) +const ( + xCustomHeader = "X-CUSTOM-HEADER" + xCustomHeaderValue = "test" +) + +var rateLimitHeaders = map[string]any{ + "x-ratelimit-limit-requests": 60, + "x-ratelimit-limit-tokens": 150000, + "x-ratelimit-remaining-requests": 59, + "x-ratelimit-remaining-tokens": 149984, + "x-ratelimit-reset-requests": "1s", + "x-ratelimit-reset-tokens": "6m0s", +} + func TestChatCompletionsWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := ChatCompletionRequest{ + req := openai.ChatCompletionRequest{ MaxTokens: 5, Model: "ada", - Messages: []ChatCompletionMessage{ + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, } _, err := client.CreateChatCompletion(ctx, req) msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) - checks.ErrorIs(t, err, ErrChatCompletionInvalidModel, msg) + checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg) +} + +func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "o1-preview_MaxTokens_deprecated", + in: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.O1Preview, + }, + expectedError: openai.ErrReasoningModelMaxTokensDeprecated, + }, + { + name: "o1-mini_MaxTokens_deprecated", + in: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.O1Mini, + }, + expectedError: openai.ErrReasoningModelMaxTokensDeprecated, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + +func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "log_probs_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + LogProbs: true, + Model: openai.O1Preview, + }, + expectedError: openai.ErrReasoningModelLimitationsLogprobs, + }, + { + name: "set_temperature_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(2), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_top_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_n_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(1), + N: 2, + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_presence_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + PresencePenalty: float32(1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_frequency_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + FrequencyPenalty: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + +func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "log_probs_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + LogProbs: true, + Model: openai.O3Mini, + }, + expectedError: openai.ErrReasoningModelLimitationsLogprobs, + }, + { + name: "set_temperature_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(2), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_top_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_n_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(1), + N: 2, + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_presence_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + PresencePenalty: float32(1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_frequency_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + FrequencyPenalty: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + +func TestGPT5ModelsChatCompletionsBetaLimitations(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "log_probs_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + LogProbs: true, + Model: openai.GPT5, + }, + expectedError: openai.ErrReasoningModelLimitationsLogprobs, + }, + { + name: "set_temperature_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(2), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_top_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5Nano, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_n_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5ChatLatest, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(1), + N: 2, + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_presence_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + PresencePenalty: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_frequency_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + FrequencyPenalty: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + +func TestChatRequestOmitEmpty(t *testing.T) { + data, err := json.Marshal(openai.ChatCompletionRequest{ + // We set model b/c it's required, so omitempty doesn't make sense + Model: "gpt-4", + }) + checks.NoError(t, err) + + // messages is also required so isn't omitted + const expected = `{"model":"gpt-4","messages":null}` + if string(data) != expected { + t.Errorf("expected JSON with all empty fields to be %v but was %v", expected, string(data)) + } } func TestChatCompletionsWithStream(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := ChatCompletionRequest{ + req := openai.ChatCompletionRequest{ Stream: true, } _, err := client.CreateChatCompletion(ctx, req) - checks.ErrorIs(t, err, ErrChatCompletionStreamNotSupported, "unexpected error") + checks.ErrorIs(t, err, openai.ErrChatCompletionStreamNotSupported, "unexpected error") } // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestChatCompletions(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") +} - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() +// TestCompletions Tests the completions endpoint of the API using the mocked server. +func TestO1ModelChatCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: openai.O1Preview, + MaxCompletionTokens: 1000, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") +} + +func TestO3ModelChatCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: openai.O3Mini, + MaxCompletionTokens: 1000, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") +} - req := ChatCompletionRequest{ +func TestDeepseekR1ModelChatCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleDeepseekR1ChatCompletionEndpoint) + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: "deepseek-reasoner", + MaxCompletionTokens: 100, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") +} + +// TestCompletions Tests the completions endpoint of the API using the mocked server. +func TestChatCompletionsWithHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, + }) + checks.NoError(t, err, "CreateChatCompletion error") + + a := resp.Header().Get(xCustomHeader) + _ = a + if resp.Header().Get(xCustomHeader) != xCustomHeaderValue { + t.Errorf("expected header %s to be %s", xCustomHeader, xCustomHeaderValue) } - _, err = client.CreateChatCompletion(ctx, req) +} + +// TestChatCompletionsWithRateLimitHeaders Tests the completions endpoint of the API using the mocked server. +func TestChatCompletionsWithRateLimitHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) checks.NoError(t, err, "CreateChatCompletion error") + + headers := resp.GetRateLimitHeaders() + resetRequests := headers.ResetRequests.String() + if resetRequests != rateLimitHeaders["x-ratelimit-reset-requests"] { + t.Errorf("expected resetRequests %s to be %s", resetRequests, rateLimitHeaders["x-ratelimit-reset-requests"]) + } + resetRequestsTime := headers.ResetRequests.Time() + if resetRequestsTime.Before(time.Now()) { + t.Errorf("unexpected reset requests: %v", resetRequestsTime) + } + + bs1, _ := json.Marshal(headers) + bs2, _ := json.Marshal(rateLimitHeaders) + if string(bs1) != string(bs2) { + t.Errorf("expected rate limit header %s to be %s", bs2, bs1) + } +} + +// TestChatCompletionsFunctions tests including a function call. +func TestChatCompletionsFunctions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + t.Run("bytes", func(t *testing.T) { + //nolint:lll + msg := json.RawMessage(`{"properties":{"count":{"type":"integer","description":"total number of words in sentence"},"words":{"items":{"type":"string"},"type":"array","description":"list of words in sentence"}},"type":"object","required":["count","words"]}`) + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []openai.FunctionDefinition{{ + Name: "test", + Parameters: &msg, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) + t.Run("struct", func(t *testing.T) { + type testMessage struct { + Count int `json:"count"` + Words []string `json:"words"` + } + msg := testMessage{ + Count: 2, + Words: []string{"hello", "world"}, + } + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []openai.FunctionDefinition{{ + Name: "test", + Parameters: &msg, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) + t.Run("JSONSchemaDefinition", func(t *testing.T) { + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []openai.FunctionDefinition{{ + Name: "test", + Parameters: &jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "count": { + Type: jsonschema.Number, + Description: "total number of words in sentence", + }, + "words": { + Type: jsonschema.Array, + Description: "list of words in sentence", + Items: &jsonschema.Definition{ + Type: jsonschema.String, + }, + }, + "enumTest": { + Type: jsonschema.String, + Enum: []string{"hello", "world"}, + }, + }, + }, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) + t.Run("JSONSchemaDefinitionWithFunctionDefine", func(t *testing.T) { + // this is a compatibility check + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []openai.FunctionDefine{{ + Name: "test", + Parameters: &jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "count": { + Type: jsonschema.Number, + Description: "total number of words in sentence", + }, + "words": { + Type: jsonschema.Array, + Description: "list of words in sentence", + Items: &jsonschema.Definition{ + Type: jsonschema.String, + }, + }, + "enumTest": { + Type: jsonschema.String, + Enum: []string{"hello", "world"}, + }, + }, + }, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) + t.Run("StructuredOutputs", func(t *testing.T) { + type testMessage struct { + Count int `json:"count"` + Words []string `json:"words"` + } + msg := testMessage{ + Count: 2, + Words: []string{"hello", "world"}, + } + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []openai.FunctionDefinition{{ + Name: "test", + Strict: true, + Parameters: &msg, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) +} + +func TestAzureChatCompletions(t *testing.T) { + client, server, teardown := setupAzureTestServer() + defer teardown() + server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint) + + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateAzureChatCompletion error") +} + +func TestMultipartChatCompletions(t *testing.T) { + client, server, teardown := setupAzureTestServer() + defer teardown() + server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint) + + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + MultiContent: []openai.ChatMessagePart{ + { + Type: openai.ChatMessagePartTypeText, + Text: "Hello!", + }, + { + Type: openai.ChatMessagePartTypeImageURL, + ImageURL: &openai.ChatMessageImageURL{ + URL: "URL", + Detail: openai.ImageURLDetailLow, + }, + }, + }, + }, + }, + }) + checks.NoError(t, err, "CreateAzureChatCompletion error") +} + +func TestMultipartChatMessageSerialization(t *testing.T) { + jsonText := `[{"role":"system","content":"system-message"},` + + `{"role":"user","content":[{"type":"text","text":"nice-text"},` + + `{"type":"image_url","image_url":{"url":"URL","detail":"high"}}]}]` + + var msgs []openai.ChatCompletionMessage + err := json.Unmarshal([]byte(jsonText), &msgs) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + if len(msgs) != 2 { + t.Errorf("unexpected number of messages") + } + if msgs[0].Role != "system" || msgs[0].Content != "system-message" || msgs[0].MultiContent != nil { + t.Errorf("invalid user message: %v", msgs[0]) + } + if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 2 { + t.Errorf("invalid user message") + } + parts := msgs[1].MultiContent + if parts[0].Type != "text" || parts[0].Text != "nice-text" { + t.Errorf("invalid text part: %v", parts[0]) + } + if parts[1].Type != "image_url" || parts[1].ImageURL.URL != "URL" || parts[1].ImageURL.Detail != "high" { + t.Errorf("invalid image_url part") + } + + s, err := json.Marshal(msgs) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + res := strings.ReplaceAll(string(s), " ", "") + if res != jsonText { + t.Fatalf("invalid message: %s", string(s)) + } + + invalidMsg := []openai.ChatCompletionMessage{ + { + Role: "user", + Content: "some-text", + MultiContent: []openai.ChatMessagePart{ + { + Type: "text", + Text: "nice-text", + }, + }, + }, + } + _, err = json.Marshal(invalidMsg) + if !errors.Is(err, openai.ErrContentFieldsMisused) { + t.Fatalf("Expected error: %s", err) + } + + err = json.Unmarshal([]byte(`["not-a-message"]`), &msgs) + if err == nil { + t.Fatalf("Expected error") + } + + emptyMultiContentMsg := openai.ChatCompletionMessage{ + Role: "user", + MultiContent: []openai.ChatMessagePart{}, + } + s, err = json.Marshal(emptyMultiContentMsg) + if err != nil { + t.Fatalf("Unexpected error") + } + res = strings.ReplaceAll(string(s), " ", "") + if res != `{"role":"user"}` { + t.Fatalf("invalid message: %s", string(s)) + } } // handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server. @@ -88,12 +885,12 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var completionReq ChatCompletionRequest + var completionReq openai.ChatCompletionRequest if completionReq, err = getChatCompletionBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := ChatCompletionResponse{ + res := openai.ChatCompletionResponse{ ID: strconv.Itoa(int(time.Now().Unix())), Object: "test-object", Created: time.Now().Unix(), @@ -103,40 +900,308 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { Model: completionReq.Model, } // create completions - for i := 0; i < completionReq.N; i++ { + n := completionReq.N + if n == 0 { + n = 1 + } + for i := 0; i < n; i++ { + // if there are functions, include them + if len(completionReq.Functions) > 0 { + var fcb []byte + b := completionReq.Functions[0].Parameters + fcb, err = json.Marshal(b) + if err != nil { + http.Error(w, "could not marshal function parameters", http.StatusInternalServerError) + return + } + + res.Choices = append(res.Choices, openai.ChatCompletionChoice{ + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleFunction, + // this is valid json so it should be fine + FunctionCall: &openai.FunctionCall{ + Name: completionReq.Functions[0].Name, + Arguments: string(fcb), + }, + }, + Index: i, + }) + continue + } // generate a random string of length completionReq.Length completionStr := strings.Repeat("a", completionReq.MaxTokens) - res.Choices = append(res.Choices, ChatCompletionChoice{ - Message: ChatCompletionMessage{ - Role: ChatMessageRoleAssistant, + res.Choices = append(res.Choices, openai.ChatCompletionChoice{ + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, Content: completionStr, }, Index: i, }) } - inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N - completionTokens := completionReq.MaxTokens * completionReq.N - res.Usage = Usage{ + inputTokens := numTokens(completionReq.Messages[0].Content) * n + completionTokens := completionReq.MaxTokens * n + res.Usage = openai.Usage{ + PromptTokens: inputTokens, + CompletionTokens: completionTokens, + TotalTokens: inputTokens + completionTokens, + } + resBytes, _ = json.Marshal(res) + w.Header().Set(xCustomHeader, xCustomHeaderValue) + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } + fmt.Fprintln(w, string(resBytes)) +} + +func handleDeepseekR1ChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var completionReq openai.ChatCompletionRequest + if completionReq, err = getChatCompletionBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + res := openai.ChatCompletionResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Object: "test-object", + Created: time.Now().Unix(), + // would be nice to validate Model during testing, but + // this may not be possible with how much upkeep + // would be required / wouldn't make much sense + Model: completionReq.Model, + } + // create completions + n := completionReq.N + if n == 0 { + n = 1 + } + if completionReq.MaxCompletionTokens == 0 { + completionReq.MaxCompletionTokens = 1000 + } + for i := 0; i < n; i++ { + reasoningContent := "User says hello! And I need to reply" + completionStr := strings.Repeat("a", completionReq.MaxCompletionTokens-numTokens(reasoningContent)) + res.Choices = append(res.Choices, openai.ChatCompletionChoice{ + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + ReasoningContent: reasoningContent, + Content: completionStr, + }, + Index: i, + }) + } + inputTokens := numTokens(completionReq.Messages[0].Content) * n + completionTokens := completionReq.MaxTokens * n + res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, } resBytes, _ = json.Marshal(res) + w.Header().Set(xCustomHeader, xCustomHeaderValue) + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } fmt.Fprintln(w, string(resBytes)) } // getChatCompletionBody Returns the body of the request to create a completion. -func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) { - completion := ChatCompletionRequest{} +func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) { + completion := openai.ChatCompletionRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return ChatCompletionRequest{}, err + return openai.ChatCompletionRequest{}, err } err = json.Unmarshal(reqBody, &completion) if err != nil { - return ChatCompletionRequest{}, err + return openai.ChatCompletionRequest{}, err } return completion, nil } + +func TestFinishReason(t *testing.T) { + c := &openai.ChatCompletionChoice{ + FinishReason: openai.FinishReasonNull, + } + resBytes, _ := json.Marshal(c) + if !strings.Contains(string(resBytes), `"finish_reason":null`) { + t.Error("null should not be quoted") + } + + c.FinishReason = "" + + resBytes, _ = json.Marshal(c) + if !strings.Contains(string(resBytes), `"finish_reason":null`) { + t.Error("null should not be quoted") + } + + otherReasons := []openai.FinishReason{ + openai.FinishReasonStop, + openai.FinishReasonLength, + openai.FinishReasonFunctionCall, + openai.FinishReasonContentFilter, + } + for _, r := range otherReasons { + c.FinishReason = r + resBytes, _ = json.Marshal(c) + if !strings.Contains(string(resBytes), fmt.Sprintf(`"finish_reason":"%s"`, r)) { + t.Errorf("%s should be quoted", r) + } + } +} + +func TestChatCompletionResponseFormatJSONSchema_UnmarshalJSON(t *testing.T) { + type args struct { + data []byte + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "", + args{ + data: []byte(`{ + "name": "math_response", + "strict": true, + "schema": { + "type": "object", + "properties": { + "steps": { + "type": "array", + "items": { + "type": "object", + "properties": { + "explanation": { "type": "string" }, + "output": { "type": "string" } + }, + "required": ["explanation","output"], + "additionalProperties": false + } + }, + "final_answer": { "type": "string" } + }, + "required": ["steps","final_answer"], + "additionalProperties": false + } + }`), + }, + false, + }, + { + "", + args{ + data: []byte(`{ + "name": "math_response", + "strict": true, + "schema": null + }`), + }, + false, + }, + { + "", + args{ + data: []byte(`[123,456]`), + }, + true, + }, + { + "", + args{ + data: []byte(`{ + "name": "math_response", + "strict": true, + "schema": 123456 + }`), + }, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var r openai.ChatCompletionResponseFormatJSONSchema + err := r.UnmarshalJSON(tt.args.data) + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestChatCompletionRequest_UnmarshalJSON(t *testing.T) { + type args struct { + bs []byte + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "", + args{bs: []byte(`{ + "model": "llama3-1b", + "messages": [ + { "role": "system", "content": "You are a helpful math tutor." }, + { "role": "user", "content": "solve 8x + 31 = 2" } + ], + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "math_response", + "strict": true, + "schema": { + "type": "object", + "properties": { + "steps": { + "type": "array", + "items": { + "type": "object", + "properties": { + "explanation": { "type": "string" }, + "output": { "type": "string" } + }, + "required": ["explanation","output"], + "additionalProperties": false + } + }, + "final_answer": { "type": "string" } + }, + "required": ["steps","final_answer"], + "additionalProperties": false + } + } + } +}`)}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var m openai.ChatCompletionRequest + err := json.Unmarshal(tt.args.bs, &m) + if err != nil { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/client.go b/client.go index 0f8aa41ba..413b8db03 100644 --- a/client.go +++ b/client.go @@ -1,20 +1,48 @@ package openai import ( + "bufio" "context" "encoding/json" "fmt" "io" "net/http" + "net/url" "strings" + + utils "github.com/sashabaranov/go-openai/internal" ) // Client is OpenAI GPT-3 API client. type Client struct { config ClientConfig - requestBuilder requestBuilder - createFormBuilder func(io.Writer) formBuilder + requestBuilder utils.RequestBuilder + createFormBuilder func(io.Writer) utils.FormBuilder +} + +type Response interface { + SetHeader(http.Header) +} + +type httpHeader http.Header + +func (h *httpHeader) SetHeader(header http.Header) { + *h = httpHeader(header) +} + +func (h *httpHeader) Header() http.Header { + return http.Header(*h) +} + +func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders { + return newRateLimitHeaders(h.Header()) +} + +type RawResponse struct { + io.ReadCloser + + httpHeader } // NewClient creates new OpenAI API client. @@ -27,9 +55,9 @@ func NewClient(authToken string) *Client { func NewClientWithConfig(config ClientConfig) *Client { return &Client{ config: config, - requestBuilder: newRequestBuilder(), - createFormBuilder: func(body io.Writer) formBuilder { - return newFormBuilder(body) + requestBuilder: utils.NewRequestBuilder(), + createFormBuilder: func(body io.Writer) utils.FormBuilder { + return utils.NewFormBuilder(body) }, } } @@ -43,25 +71,70 @@ func NewOrgClient(authToken, org string) *Client { return NewClientWithConfig(config) } -func (c *Client) sendRequest(req *http.Request, v any) error { - req.Header.Set("Accept", "application/json; charset=utf-8") - // Azure API Key authentication - if c.config.APIType == APITypeAzure { - req.Header.Set(AzureAPIKeyHeader, c.config.authToken) - } else { - // OpenAI or Azure AD authentication - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) +type requestOptions struct { + body any + header http.Header +} + +type requestOption func(*requestOptions) + +func withBody(body any) requestOption { + return func(args *requestOptions) { + args.body = body + } +} + +func withExtraBody(extraBody map[string]any) requestOption { + return func(args *requestOptions) { + // Assert that args.body is a map[string]any. + bodyMap, ok := args.body.(map[string]any) + if ok { + // If it's a map[string]any then only add extraBody + // fields to args.body otherwise keep only fields in request struct. + for key, value := range extraBody { + bodyMap[key] = value + } + } + } +} + +func withContentType(contentType string) requestOption { + return func(args *requestOptions) { + args.header.Set("Content-Type", contentType) + } +} + +func withBetaAssistantVersion(version string) requestOption { + return func(args *requestOptions) { + args.header.Set("OpenAI-Beta", fmt.Sprintf("assistants=%s", version)) } +} + +func (c *Client) newRequest(ctx context.Context, method, url string, setters ...requestOption) (*http.Request, error) { + // Default Options + args := &requestOptions{ + body: nil, + header: make(http.Header), + } + for _, setter := range setters { + setter(args) + } + req, err := c.requestBuilder.Build(ctx, method, url, args.body, args.header) + if err != nil { + return nil, err + } + c.setCommonHeaders(req) + return req, nil +} + +func (c *Client) sendRequest(req *http.Request, v Response) error { + req.Header.Set("Accept", "application/json") // Check whether Content-Type is already set, Upload Files API requires // Content-Type == multipart/form-data contentType := req.Header.Get("Content-Type") if contentType == "" { - req.Header.Set("Content-Type", "application/json; charset=utf-8") - } - - if len(c.config.OrgID) > 0 { - req.Header.Set("OpenAI-Organization", c.config.OrgID) + req.Header.Set("Content-Type", "application/json") } res, err := c.config.HTTPClient.Do(req) @@ -71,22 +144,95 @@ func (c *Client) sendRequest(req *http.Request, v any) error { defer res.Body.Close() - if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { + if v != nil { + v.SetHeader(res.Header) + } + + if isFailureStatusCode(res) { return c.handleErrorResp(res) } return decodeResponse(res.Body, v) } +func (c *Client) sendRequestRaw(req *http.Request) (response RawResponse, err error) { + resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body should be closed by outer function + if err != nil { + return + } + + if isFailureStatusCode(resp) { + err = c.handleErrorResp(resp) + return + } + + response.SetHeader(resp.Header) + response.ReadCloser = resp.Body + return +} + +func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + + resp, err := client.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + if err != nil { + return new(streamReader[T]), err + } + if isFailureStatusCode(resp) { + return new(streamReader[T]), client.handleErrorResp(resp) + } + return &streamReader[T]{ + emptyMessagesLimit: client.config.EmptyMessagesLimit, + reader: bufio.NewReader(resp.Body), + response: resp, + errAccumulator: utils.NewErrorAccumulator(), + unmarshaler: &utils.JSONUnmarshaler{}, + httpHeader: httpHeader(resp.Header), + }, nil +} + +func (c *Client) setCommonHeaders(req *http.Request) { + // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication + switch c.config.APIType { + case APITypeAzure, APITypeCloudflareAzure: + // Azure API Key authentication + req.Header.Set(AzureAPIKeyHeader, c.config.authToken) + case APITypeAnthropic: + // https://docs.anthropic.com/en/api/versioning + req.Header.Set("anthropic-version", c.config.APIVersion) + case APITypeOpenAI, APITypeAzureAD: + fallthrough + default: + if c.config.authToken != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + } + } + + if c.config.OrgID != "" { + req.Header.Set("OpenAI-Organization", c.config.OrgID) + } +} + +func isFailureStatusCode(resp *http.Response) bool { + return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest +} + func decodeResponse(body io.Reader, v any) error { if v == nil { return nil } - if result, ok := v.(*string); ok { - return decodeString(body, result) + switch o := v.(type) { + case *string: + return decodeString(body, o) + case *audioTextResponse: + return decodeString(body, &o.Text) + default: + return json.NewDecoder(body).Decode(v) } - return json.NewDecoder(body).Decode(v) } func decodeString(body io.Reader, output *string) error { @@ -98,60 +244,81 @@ func decodeString(body io.Reader, output *string) error { return nil } -func (c *Client) fullURL(suffix string) string { - // /openai/deployments/{engine}/chat/completions?api-version={api_version} - if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { - baseURL := c.config.BaseURL - baseURL = strings.TrimRight(baseURL, "/") - // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 - // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP - if strings.Contains(suffix, "/models") { - return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion) - } - return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s", - baseURL, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion) +type fullURLOptions struct { + model string +} + +type fullURLOption func(*fullURLOptions) + +func withModel(model string) fullURLOption { + return func(args *fullURLOptions) { + args.model = model } +} - // c.config.APIType == APITypeOpenAI || c.config.APIType == "" - return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) +var azureDeploymentsEndpoints = []string{ + "/completions", + "/embeddings", + "/chat/completions", + "/audio/transcriptions", + "/audio/translations", + "/audio/speech", + "/images/generations", } -func (c *Client) newStreamRequest( - ctx context.Context, - method string, - urlSuffix string, - body any) (*http.Request, error) { - req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix), body) - if err != nil { - return nil, err +// fullURL returns full URL for request. +func (c *Client) fullURL(suffix string, setters ...fullURLOption) string { + baseURL := strings.TrimRight(c.config.BaseURL, "/") + args := fullURLOptions{} + for _, setter := range setters { + setter(&args) } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Connection", "keep-alive") + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { + baseURL = c.baseURLWithAzureDeployment(baseURL, suffix, args.model) + } - // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication - // Azure API Key authentication - if c.config.APIType == APITypeAzure { - req.Header.Set(AzureAPIKeyHeader, c.config.authToken) - } else { - // OpenAI or Azure AD authentication - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + if c.config.APIVersion != "" { + suffix = c.suffixWithAPIVersion(suffix) } - if c.config.OrgID != "" { - req.Header.Set("OpenAI-Organization", c.config.OrgID) + return fmt.Sprintf("%s%s", baseURL, suffix) +} + +func (c *Client) suffixWithAPIVersion(suffix string) string { + parsedSuffix, err := url.Parse(suffix) + if err != nil { + panic("failed to parse url suffix") } - return req, nil + query := parsedSuffix.Query() + query.Add("api-version", c.config.APIVersion) + return fmt.Sprintf("%s?%s", parsedSuffix.Path, query.Encode()) +} + +func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newBaseURL string) { + baseURL = fmt.Sprintf("%s/%s", strings.TrimRight(baseURL, "/"), azureAPIPrefix) + if containsSubstr(azureDeploymentsEndpoints, suffix) { + azureDeploymentName := c.config.GetAzureDeploymentByModel(model) + if azureDeploymentName == "" { + azureDeploymentName = "UNKNOWN" + } + baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName) + } + return baseURL } func (c *Client) handleErrorResp(resp *http.Response) error { + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error, reading response body: %w", err) + } var errRes ErrorResponse - err := json.NewDecoder(resp.Body).Decode(&errRes) + err = json.Unmarshal(body, &errRes) if err != nil || errRes.Error == nil { reqErr := &RequestError{ + HTTPStatus: resp.Status, HTTPStatusCode: resp.StatusCode, Err: err, + Body: body, } if errRes.Error != nil { reqErr.Err = errRes.Error @@ -159,6 +326,16 @@ func (c *Client) handleErrorResp(resp *http.Response) error { return reqErr } + errRes.Error.HTTPStatus = resp.Status errRes.Error.HTTPStatusCode = resp.StatusCode return errRes.Error } + +func containsSubstr(s []string, e string) bool { + for _, v := range s { + if strings.Contains(e, v) { + return true + } + } + return false +} diff --git a/client_test.go b/client_test.go index e30fa399b..321971445 100644 --- a/client_test.go +++ b/client_test.go @@ -2,13 +2,26 @@ package openai //nolint:testpackage // testing private field import ( "bytes" + "context" "errors" "fmt" "io" "net/http" + "reflect" "testing" + + "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" ) +var errTestRequestBuilderFailed = errors.New("test request builder failed") + +type failingRequestBuilder struct{} + +func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any, _ http.Header) (*http.Request, error) { + return nil, errTestRequestBuilderFailed +} + func TestClient(t *testing.T) { const mockToken = "mock token" client := NewClient(mockToken) @@ -26,41 +39,101 @@ func TestClient(t *testing.T) { } } +func TestSetCommonHeadersAnthropic(t *testing.T) { + config := DefaultAnthropicConfig("mock-token", "") + client := NewClientWithConfig(config) + req, err := http.NewRequest("GET", "http://example.com", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + client.setCommonHeaders(req) + + if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion { + t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got) + } +} + func TestDecodeResponse(t *testing.T) { stringInput := "" testCases := []struct { - name string - value interface{} - body io.Reader + name string + value interface{} + expected interface{} + body io.Reader + hasError bool }{ { - name: "nil input", - value: nil, - body: bytes.NewReader([]byte("")), + name: "nil input", + value: nil, + body: bytes.NewReader([]byte("")), + expected: nil, }, { - name: "string input", - value: &stringInput, - body: bytes.NewReader([]byte("test")), + name: "string input", + value: &stringInput, + body: bytes.NewReader([]byte("test")), + expected: "test", }, { name: "map input", value: &map[string]interface{}{}, body: bytes.NewReader([]byte(`{"test": "test"}`)), + expected: map[string]interface{}{ + "test": "test", + }, + }, + { + name: "reader return error", + value: &stringInput, + body: &errorReader{err: errors.New("dummy")}, + hasError: true, }, + { + name: "audio text input", + value: &audioTextResponse{}, + body: bytes.NewReader([]byte("test")), + expected: audioTextResponse{ + Text: "test", + }, + }, + } + + assertEqual := func(t *testing.T, expected, actual interface{}) { + t.Helper() + if expected == actual { + return + } + v := reflect.ValueOf(actual).Elem().Interface() + if !reflect.DeepEqual(v, expected) { + t.Fatalf("Unexpected value: %v, expected: %v", v, expected) + } } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { err := decodeResponse(tc.body, tc.value) + if tc.hasError { + checks.HasError(t, err, "Unexpected nil error") + return + } if err != nil { - t.Errorf("Unexpected error: %v", err) + t.Fatalf("Unexpected error: %v", err) } + assertEqual(t, tc.expected, tc.value) }) } } +type errorReader struct { + err error +} + +func (e *errorReader) Read(_ []byte) (n int, err error) { + return 0, e.err +} + func TestHandleErrorResp(t *testing.T) { // var errRes *ErrorResponse var errRes ErrorResponse @@ -76,14 +149,17 @@ func TestHandleErrorResp(t *testing.T) { client := NewClient(mockToken) testCases := []struct { - name string - httpCode int - body io.Reader - expected string + name string + httpCode int + httpStatus string + contentType string + body io.Reader + expected string }{ { - name: "401 Invalid Authentication", - httpCode: http.StatusUnauthorized, + name: "401 Invalid Authentication", + httpCode: http.StatusUnauthorized, + contentType: "application/json", body: bytes.NewReader([]byte( `{ "error":{ @@ -94,11 +170,12 @@ func TestHandleErrorResp(t *testing.T) { } }`, )), - expected: "error, status code: 401, message: You didn't provide an API key. ....", + expected: "error, status code: 401, status: , message: You didn't provide an API key. ....", }, { - name: "401 Azure Access Denied", - httpCode: http.StatusUnauthorized, + name: "401 Azure Access Denied", + httpCode: http.StatusUnauthorized, + contentType: "application/json", body: bytes.NewReader([]byte( `{ "error":{ @@ -107,11 +184,12 @@ func TestHandleErrorResp(t *testing.T) { } }`, )), - expected: "error, status code: 401, message: Access denied due to Virtual Network/Firewall rules.", + expected: "error, status code: 401, status: , message: Access denied due to Virtual Network/Firewall rules.", }, { - name: "503 Model Overloaded", - httpCode: http.StatusServiceUnavailable, + name: "503 Model Overloaded", + httpCode: http.StatusServiceUnavailable, + contentType: "application/json", body: bytes.NewReader([]byte(` { "error":{ @@ -121,13 +199,58 @@ func TestHandleErrorResp(t *testing.T) { "code":null } }`)), - expected: "error, status code: 503, message: That model...", + expected: "error, status code: 503, status: , message: That model...", + }, + { + name: "503 no message (Unknown response)", + httpCode: http.StatusServiceUnavailable, + contentType: "application/json", + body: bytes.NewReader([]byte(` + { + "error":{} + }`)), + expected: `error, status code: 503, status: , message: , body: + { + "error":{} + }`, + }, + { + name: "413 Request Entity Too Large", + httpCode: http.StatusRequestEntityTooLarge, + contentType: "text/html", + body: bytes.NewReader([]byte(` + + 413 Request Entity Too Large + +

413 Request Entity Too Large

+
nginx
+ + `)), + expected: `error, status code: 413, status: , message: invalid character '<' looking for beginning of value, body: + + 413 Request Entity Too Large + +

413 Request Entity Too Large

+
nginx
+ + `, + }, + { + name: "errorReader", + httpCode: http.StatusRequestEntityTooLarge, + contentType: "text/html", + body: &errorReader{err: errors.New("errorReader")}, + expected: "error, reading response body: errorReader", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - testCase := &http.Response{} + testCase := &http.Response{ + Header: map[string][]string{ + "Content-Type": {tc.contentType}, + }, + } testCase.StatusCode = tc.httpCode testCase.Body = io.NopCloser(tc.body) err := client.handleErrorResp(testCase) @@ -136,11 +259,329 @@ func TestHandleErrorResp(t *testing.T) { t.Errorf("Unexpected error: %v , expected: %s", err, tc.expected) t.Fail() } + }) + } +} - e := &APIError{} - if !errors.As(err, &e) { - t.Errorf("(%s) Expected error to be of type APIError", tc.name) - t.Fail() +func TestClientReturnsRequestBuilderErrors(t *testing.T) { + config := DefaultConfig(test.GetTestToken()) + client := NewClientWithConfig(config) + client.requestBuilder = &failingRequestBuilder{} + ctx := context.Background() + + type TestCase struct { + Name string + TestFunc func() (any, error) + } + + testCases := []TestCase{ + {"CreateCompletion", func() (any, error) { + return client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"}) + }}, + {"CreateCompletionStream", func() (any, error) { + return client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""}) + }}, + {"CreateChatCompletion", func() (any, error) { + return client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) + }}, + {"CreateChatCompletionStream", func() (any, error) { + return client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) + }}, + {"CreateFineTune", func() (any, error) { + return client.CreateFineTune(ctx, FineTuneRequest{}) + }}, + {"ListFineTunes", func() (any, error) { + return client.ListFineTunes(ctx) + }}, + {"CancelFineTune", func() (any, error) { + return client.CancelFineTune(ctx, "") + }}, + {"GetFineTune", func() (any, error) { + return client.GetFineTune(ctx, "") + }}, + {"DeleteFineTune", func() (any, error) { + return client.DeleteFineTune(ctx, "") + }}, + {"ListFineTuneEvents", func() (any, error) { + return client.ListFineTuneEvents(ctx, "") + }}, + {"CreateFineTuningJob", func() (any, error) { + return client.CreateFineTuningJob(ctx, FineTuningJobRequest{}) + }}, + {"CancelFineTuningJob", func() (any, error) { + return client.CancelFineTuningJob(ctx, "") + }}, + {"RetrieveFineTuningJob", func() (any, error) { + return client.RetrieveFineTuningJob(ctx, "") + }}, + {"ListFineTuningJobEvents", func() (any, error) { + return client.ListFineTuningJobEvents(ctx, "") + }}, + {"Moderations", func() (any, error) { + return client.Moderations(ctx, ModerationRequest{}) + }}, + {"Edits", func() (any, error) { + return client.Edits(ctx, EditsRequest{}) + }}, + {"CreateEmbeddings", func() (any, error) { + return client.CreateEmbeddings(ctx, EmbeddingRequest{}) + }}, + {"CreateImage", func() (any, error) { + return client.CreateImage(ctx, ImageRequest{}) + }}, + {"CreateFileBytes", func() (any, error) { + return client.CreateFileBytes(ctx, FileBytesRequest{}) + }}, + {"DeleteFile", func() (any, error) { + return nil, client.DeleteFile(ctx, "") + }}, + {"GetFile", func() (any, error) { + return client.GetFile(ctx, "") + }}, + {"GetFileContent", func() (any, error) { + return client.GetFileContent(ctx, "") + }}, + {"ListFiles", func() (any, error) { + return client.ListFiles(ctx) + }}, + {"ListEngines", func() (any, error) { + return client.ListEngines(ctx) + }}, + {"GetEngine", func() (any, error) { + return client.GetEngine(ctx, "") + }}, + {"ListModels", func() (any, error) { + return client.ListModels(ctx) + }}, + {"GetModel", func() (any, error) { + return client.GetModel(ctx, "text-davinci-003") + }}, + {"DeleteFineTuneModel", func() (any, error) { + return client.DeleteFineTuneModel(ctx, "") + }}, + {"CreateAssistant", func() (any, error) { + return client.CreateAssistant(ctx, AssistantRequest{}) + }}, + {"RetrieveAssistant", func() (any, error) { + return client.RetrieveAssistant(ctx, "") + }}, + {"ModifyAssistant", func() (any, error) { + return client.ModifyAssistant(ctx, "", AssistantRequest{}) + }}, + {"DeleteAssistant", func() (any, error) { + return client.DeleteAssistant(ctx, "") + }}, + {"ListAssistants", func() (any, error) { + return client.ListAssistants(ctx, nil, nil, nil, nil) + }}, + {"CreateAssistantFile", func() (any, error) { + return client.CreateAssistantFile(ctx, "", AssistantFileRequest{}) + }}, + {"ListAssistantFiles", func() (any, error) { + return client.ListAssistantFiles(ctx, "", nil, nil, nil, nil) + }}, + {"RetrieveAssistantFile", func() (any, error) { + return client.RetrieveAssistantFile(ctx, "", "") + }}, + {"DeleteAssistantFile", func() (any, error) { + return nil, client.DeleteAssistantFile(ctx, "", "") + }}, + {"CreateMessage", func() (any, error) { + return client.CreateMessage(ctx, "", MessageRequest{}) + }}, + {"ListMessage", func() (any, error) { + return client.ListMessage(ctx, "", nil, nil, nil, nil, nil) + }}, + {"RetrieveMessage", func() (any, error) { + return client.RetrieveMessage(ctx, "", "") + }}, + {"ModifyMessage", func() (any, error) { + return client.ModifyMessage(ctx, "", "", nil) + }}, + {"DeleteMessage", func() (any, error) { + return client.DeleteMessage(ctx, "", "") + }}, + {"RetrieveMessageFile", func() (any, error) { + return client.RetrieveMessageFile(ctx, "", "", "") + }}, + {"ListMessageFiles", func() (any, error) { + return client.ListMessageFiles(ctx, "", "") + }}, + {"CreateThread", func() (any, error) { + return client.CreateThread(ctx, ThreadRequest{}) + }}, + {"RetrieveThread", func() (any, error) { + return client.RetrieveThread(ctx, "") + }}, + {"ModifyThread", func() (any, error) { + return client.ModifyThread(ctx, "", ModifyThreadRequest{}) + }}, + {"DeleteThread", func() (any, error) { + return client.DeleteThread(ctx, "") + }}, + {"CreateRun", func() (any, error) { + return client.CreateRun(ctx, "", RunRequest{}) + }}, + {"RetrieveRun", func() (any, error) { + return client.RetrieveRun(ctx, "", "") + }}, + {"ModifyRun", func() (any, error) { + return client.ModifyRun(ctx, "", "", RunModifyRequest{}) + }}, + {"ListRuns", func() (any, error) { + return client.ListRuns(ctx, "", Pagination{}) + }}, + {"SubmitToolOutputs", func() (any, error) { + return client.SubmitToolOutputs(ctx, "", "", SubmitToolOutputsRequest{}) + }}, + {"CancelRun", func() (any, error) { + return client.CancelRun(ctx, "", "") + }}, + {"CreateThreadAndRun", func() (any, error) { + return client.CreateThreadAndRun(ctx, CreateThreadAndRunRequest{}) + }}, + {"RetrieveRunStep", func() (any, error) { + return client.RetrieveRunStep(ctx, "", "", "") + }}, + {"ListRunSteps", func() (any, error) { + return client.ListRunSteps(ctx, "", "", Pagination{}) + }}, + {"CreateSpeech", func() (any, error) { + return client.CreateSpeech(ctx, CreateSpeechRequest{Model: TTSModel1, Voice: VoiceAlloy}) + }}, + {"CreateBatch", func() (any, error) { + return client.CreateBatch(ctx, CreateBatchRequest{}) + }}, + {"CreateBatchWithUploadFile", func() (any, error) { + return client.CreateBatchWithUploadFile(ctx, CreateBatchWithUploadFileRequest{}) + }}, + {"RetrieveBatch", func() (any, error) { + return client.RetrieveBatch(ctx, "") + }}, + {"CancelBatch", func() (any, error) { return client.CancelBatch(ctx, "") }}, + {"ListBatch", func() (any, error) { return client.ListBatch(ctx, nil, nil) }}, + } + + for _, testCase := range testCases { + _, err := testCase.TestFunc() + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("%s did not return error when request builder failed: %v", testCase.Name, err) + } + } +} + +func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) { + config := DefaultConfig(test.GetTestToken()) + client := NewClientWithConfig(config) + client.requestBuilder = &failingRequestBuilder{} + ctx := context.Background() + _, err := client.CreateCompletion(ctx, CompletionRequest{Prompt: 1}) + if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + _, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1}) + if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } +} + +func TestClient_suffixWithAPIVersion(t *testing.T) { + type fields struct { + apiVersion string + } + type args struct { + suffix string + } + tests := []struct { + name string + fields fields + args args + want string + wantPanic string + }{ + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "/assistants"}, + "/assistants?api-version=2023-05", + "", + }, + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "/assistants?limit=5"}, + "/assistants?api-version=2023-05&limit=5", + "", + }, + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "123:assistants?limit=5"}, + "/assistants?api-version=2023-05&limit=5", + "failed to parse url suffix", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ + config: ClientConfig{APIVersion: tt.fields.apiVersion}, + } + defer func() { + if r := recover(); r != nil { + // Check if the panic message matches the expected panic message + if rStr, ok := r.(string); ok { + if rStr != tt.wantPanic { + t.Errorf("suffixWithAPIVersion() = %v, want %v", rStr, tt.wantPanic) + } + } else { + // If the panic is not a string, log it + t.Errorf("suffixWithAPIVersion() panicked with non-string value: %v", r) + } + } + }() + if got := c.suffixWithAPIVersion(tt.args.suffix); got != tt.want { + t.Errorf("suffixWithAPIVersion() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestClient_baseURLWithAzureDeployment(t *testing.T) { + type args struct { + baseURL string + suffix string + model string + } + tests := []struct { + name string + args args + wantNewBaseURL string + }{ + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: assistantsSuffix, model: GPT4oMini}, + "https://test.openai.azure.com/openai", + }, + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: GPT4oMini}, + "https://test.openai.azure.com/openai/deployments/gpt-4o-mini", + }, + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: ""}, + "https://test.openai.azure.com/openai/deployments/UNKNOWN", + }, + } + client := NewClient("") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotNewBaseURL := client.baseURLWithAzureDeployment( + tt.args.baseURL, + tt.args.suffix, + tt.args.model, + ); gotNewBaseURL != tt.wantNewBaseURL { + t.Errorf("baseURLWithAzureDeployment() = %v, want %v", gotNewBaseURL, tt.wantNewBaseURL) } }) } diff --git a/common.go b/common.go index cbfda4e3c..d1936d656 100644 --- a/common.go +++ b/common.go @@ -4,7 +4,23 @@ package openai // Usage Represents the total token usage per request to OpenAI. type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details"` + CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details"` +} + +// CompletionTokensDetails Breakdown of tokens used in a completion. +type CompletionTokensDetails struct { + AudioTokens int `json:"audio_tokens"` + ReasoningTokens int `json:"reasoning_tokens"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens"` +} + +// PromptTokensDetails Breakdown of tokens used in the prompt. +type PromptTokensDetails struct { + AudioTokens int `json:"audio_tokens"` + CachedTokens int `json:"cached_tokens"` } diff --git a/completion.go b/completion.go index 5eec88c29..27d69f587 100644 --- a/completion.go +++ b/completion.go @@ -2,39 +2,92 @@ package openai import ( "context" - "errors" "net/http" ) -var ( - ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll - ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll - ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll -) - // GPT3 Defines the models provided by OpenAI to use when generating // completions from OpenAI. // GPT3 Models are designed for text-based tasks. For code-specific // tasks, please refer to the Codex series of models. const ( + O1Mini = "o1-mini" + O1Mini20240912 = "o1-mini-2024-09-12" + O1Preview = "o1-preview" + O1Preview20240912 = "o1-preview-2024-09-12" + O1 = "o1" + O120241217 = "o1-2024-12-17" + O3 = "o3" + O320250416 = "o3-2025-04-16" + O3Mini = "o3-mini" + O3Mini20250131 = "o3-mini-2025-01-31" + O4Mini = "o4-mini" + O4Mini20250416 = "o4-mini-2025-04-16" + GPT432K0613 = "gpt-4-32k-0613" GPT432K0314 = "gpt-4-32k-0314" GPT432K = "gpt-4-32k" + GPT40613 = "gpt-4-0613" GPT40314 = "gpt-4-0314" + GPT4o = "gpt-4o" + GPT4o20240513 = "gpt-4o-2024-05-13" + GPT4o20240806 = "gpt-4o-2024-08-06" + GPT4o20241120 = "gpt-4o-2024-11-20" + GPT4oLatest = "chatgpt-4o-latest" + GPT4oMini = "gpt-4o-mini" + GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" + GPT4Turbo = "gpt-4-turbo" + GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" + GPT4Turbo0125 = "gpt-4-0125-preview" + GPT4Turbo1106 = "gpt-4-1106-preview" + GPT4TurboPreview = "gpt-4-turbo-preview" + GPT4VisionPreview = "gpt-4-vision-preview" GPT4 = "gpt-4" + GPT4Dot1 = "gpt-4.1" + GPT4Dot120250414 = "gpt-4.1-2025-04-14" + GPT4Dot1Mini = "gpt-4.1-mini" + GPT4Dot1Mini20250414 = "gpt-4.1-mini-2025-04-14" + GPT4Dot1Nano = "gpt-4.1-nano" + GPT4Dot1Nano20250414 = "gpt-4.1-nano-2025-04-14" + GPT4Dot5Preview = "gpt-4.5-preview" + GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27" + GPT5 = "gpt-5" + GPT5Mini = "gpt-5-mini" + GPT5Nano = "gpt-5-nano" + GPT5ChatLatest = "gpt-5-chat-latest" + GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" + GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" + GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" + GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" + GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" GPT3Dot5Turbo = "gpt-3.5-turbo" - GPT3TextDavinci003 = "text-davinci-003" - GPT3TextDavinci002 = "text-davinci-002" - GPT3TextCurie001 = "text-curie-001" - GPT3TextBabbage001 = "text-babbage-001" - GPT3TextAda001 = "text-ada-001" - GPT3TextDavinci001 = "text-davinci-001" + GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct" + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. + GPT3TextDavinci003 = "text-davinci-003" + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. + GPT3TextDavinci002 = "text-davinci-002" + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. + GPT3TextCurie001 = "text-curie-001" + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. + GPT3TextBabbage001 = "text-babbage-001" + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. + GPT3TextAda001 = "text-ada-001" + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. + GPT3TextDavinci001 = "text-davinci-001" + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3DavinciInstructBeta = "davinci-instruct-beta" - GPT3Davinci = "davinci" - GPT3CurieInstructBeta = "curie-instruct-beta" - GPT3Curie = "curie" - GPT3Ada = "ada" - GPT3Babbage = "babbage" + // Deprecated: Model is shutdown. Use davinci-002 instead. + GPT3Davinci = "davinci" + GPT3Davinci002 = "davinci-002" + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. + GPT3CurieInstructBeta = "curie-instruct-beta" + GPT3Curie = "curie" + GPT3Curie002 = "curie-002" + // Deprecated: Model is shutdown. Use babbage-002 instead. + GPT3Ada = "ada" + GPT3Ada002 = "ada-002" + // Deprecated: Model is shutdown. Use babbage-002 instead. + GPT3Babbage = "babbage" + GPT3Babbage002 = "babbage-002" ) // Codex Defines the models provided by OpenAI. @@ -48,14 +101,57 @@ const ( var disabledModelsForEndpoints = map[string]map[string]bool{ "/completions": { - GPT3Dot5Turbo: true, - GPT3Dot5Turbo0301: true, - GPT4: true, - GPT40314: true, - GPT432K: true, - GPT432K0314: true, + O1Mini: true, + O1Mini20240912: true, + O1Preview: true, + O1Preview20240912: true, + O3Mini: true, + O3Mini20250131: true, + O4Mini: true, + O4Mini20250416: true, + O3: true, + O320250416: true, + GPT3Dot5Turbo: true, + GPT3Dot5Turbo0301: true, + GPT3Dot5Turbo0613: true, + GPT3Dot5Turbo1106: true, + GPT3Dot5Turbo0125: true, + GPT3Dot5Turbo16K: true, + GPT3Dot5Turbo16K0613: true, + GPT4: true, + GPT4Dot5Preview: true, + GPT4Dot5Preview20250227: true, + GPT4o: true, + GPT4o20240513: true, + GPT4o20240806: true, + GPT4o20241120: true, + GPT4oLatest: true, + GPT4oMini: true, + GPT4oMini20240718: true, + GPT4TurboPreview: true, + GPT4VisionPreview: true, + GPT4Turbo1106: true, + GPT4Turbo0125: true, + GPT4Turbo: true, + GPT4Turbo20240409: true, + GPT40314: true, + GPT40613: true, + GPT432K: true, + GPT432K0314: true, + GPT432K0613: true, + O1: true, + GPT4Dot1: true, + GPT4Dot120250414: true, + GPT4Dot1Mini: true, + GPT4Dot1Mini20250414: true, + GPT4Dot1Nano: true, + GPT4Dot1Nano20250414: true, + GPT5: true, + GPT5Mini: true, + GPT5Nano: true, + GPT5ChatLatest: true, }, - "/chat/completions": { + chatCompletionsSuffix: { CodexCodeDavinci002: true, CodexCodeCushman001: true, CodexCodeDavinci001: true, @@ -81,27 +177,54 @@ func checkEndpointSupportsModel(endpoint, model string) bool { func checkPromptType(prompt any) bool { _, isString := prompt.(string) _, isStringSlice := prompt.([]string) - return isString || isStringSlice + if isString || isStringSlice { + return true + } + + // check if it is prompt is []string hidden under []any + slice, isSlice := prompt.([]any) + if !isSlice { + return false + } + + for _, item := range slice { + _, itemIsString := item.(string) + if !itemIsString { + return false + } + } + return true // all items in the slice are string, so it is []string } // CompletionRequest represents a request structure for completion API. type CompletionRequest struct { - Model string `json:"model"` - Prompt any `json:"prompt,omitempty"` - Suffix string `json:"suffix,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - LogProbs int `json:"logprobs,omitempty"` - Echo bool `json:"echo,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` - BestOf int `json:"best_of,omitempty"` - LogitBias map[string]int `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` + Model string `json:"model"` + Prompt any `json:"prompt,omitempty"` + BestOf int `json:"best_of,omitempty"` + Echo bool `json:"echo,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. + // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` + // refs: https://platform.openai.com/docs/api-reference/completions/create#completions/create-logit_bias + LogitBias map[string]int `json:"logit_bias,omitempty"` + // Store can be set to true to store the output of this completion request for use in distillations and evals. + // https://platform.openai.com/docs/api-reference/chat/create#chat-create-store + Store bool `json:"store,omitempty"` + // Metadata to store with the completion. + Metadata map[string]string `json:"metadata,omitempty"` + LogProbs int `json:"logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + Seed *int `json:"seed,omitempty"` + Stop []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + Suffix string `json:"suffix,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + User string `json:"user,omitempty"` + // Options for streaming response. Only set this when you set stream: true. + StreamOptions *StreamOptions `json:"stream_options,omitempty"` } // CompletionChoice represents one of possible completions. @@ -127,7 +250,9 @@ type CompletionResponse struct { Created int64 `json:"created"` Model string `json:"model"` Choices []CompletionChoice `json:"choices"` - Usage Usage `json:"usage"` + Usage *Usage `json:"usage,omitempty"` + + httpHeader } // CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well @@ -155,7 +280,12 @@ func (c *Client) CreateCompletion( return } - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } diff --git a/completion_test.go b/completion_test.go index 2e302591a..abfc3007e 100644 --- a/completion_test.go +++ b/completion_test.go @@ -1,10 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "errors" @@ -15,58 +11,119 @@ import ( "strings" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestCompletionsWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) _, err := client.CreateCompletion( context.Background(), - CompletionRequest{ + openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, + Model: openai.GPT3Dot5Turbo, }, ) - if !errors.Is(err, ErrCompletionUnsupportedModel) { + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) } } +// TestCompletionsWrongModelO3 Tests the completions endpoint with O3 model which is not supported. +func TestCompletionsWrongModelO3(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: openai.O3, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O3, but returned: %v", err) + } +} + +// TestCompletionsWrongModelO4Mini Tests the completions endpoint with O4Mini model which is not supported. +func TestCompletionsWrongModelO4Mini(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: openai.O4Mini, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O4Mini, but returned: %v", err) + } +} + func TestCompletionWithStream(t *testing.T) { - config := DefaultConfig("whatever") - client := NewClientWithConfig(config) + config := openai.DefaultConfig("whatever") + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := CompletionRequest{Stream: true} + req := openai.CompletionRequest{Stream: true} _, err := client.CreateCompletion(ctx, req) - if !errors.Is(err, ErrCompletionStreamNotSupported) { + if !errors.Is(err, openai.ErrCompletionStreamNotSupported) { t.Fatalf("CreateCompletion didn't return ErrCompletionStreamNotSupported") } } // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestCompletions(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/completions", handleCompletionEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() + req := openai.CompletionRequest{ + MaxTokens: 5, + Model: "ada", + Prompt: "Lorem ipsum", + } + _, err := client.CreateCompletion(context.Background(), req) + checks.NoError(t, err, "CreateCompletion error") +} - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() +// TestMultiplePromptsCompletionsWrong Tests the completions endpoint of the API using the mocked server +// where the completions requests has a list of prompts with wrong type. +func TestMultiplePromptsCompletionsWrong(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", handleCompletionEndpoint) + req := openai.CompletionRequest{ + MaxTokens: 5, + Model: "ada", + Prompt: []interface{}{"Lorem ipsum", 9}, + } + _, err := client.CreateCompletion(context.Background(), req) + if !errors.Is(err, openai.ErrCompletionRequestPromptTypeNotSupported) { + t.Fatalf("CreateCompletion should return ErrCompletionRequestPromptTypeNotSupported, but returned: %v", err) + } +} - req := CompletionRequest{ +// TestMultiplePromptsCompletions Tests the completions endpoint of the API using the mocked server +// where the completions requests has a list of prompts. +func TestMultiplePromptsCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", handleCompletionEndpoint) + req := openai.CompletionRequest{ MaxTokens: 5, Model: "ada", + Prompt: []interface{}{"Lorem ipsum", "Lorem ipsum"}, } - req.Prompt = "Lorem ipsum" - _, err = client.CreateCompletion(ctx, req) + _, err := client.CreateCompletion(context.Background(), req) checks.NoError(t, err, "CreateCompletion error") } @@ -79,12 +136,12 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var completionReq CompletionRequest + var completionReq openai.CompletionRequest if completionReq, err = getCompletionBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := CompletionResponse{ + res := openai.CompletionResponse{ ID: strconv.Itoa(int(time.Now().Unix())), Object: "test-object", Created: time.Now().Unix(), @@ -94,39 +151,181 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { Model: completionReq.Model, } // create completions - for i := 0; i < completionReq.N; i++ { - // generate a random string of length completionReq.Length - completionStr := strings.Repeat("a", completionReq.MaxTokens) - if completionReq.Echo { - completionStr = completionReq.Prompt.(string) + completionStr + n := completionReq.N + if n == 0 { + n = 1 + } + // Handle different types of prompts: single string or list of strings + prompts := []string{} + switch v := completionReq.Prompt.(type) { + case string: + prompts = append(prompts, v) + case []interface{}: + for _, item := range v { + if str, ok := item.(string); ok { + prompts = append(prompts, str) + } } - res.Choices = append(res.Choices, CompletionChoice{ - Text: completionStr, - Index: i, - }) + default: + http.Error(w, "Invalid prompt type", http.StatusBadRequest) + return } - inputTokens := numTokens(completionReq.Prompt.(string)) * completionReq.N - completionTokens := completionReq.MaxTokens * completionReq.N - res.Usage = Usage{ + + for i := 0; i < n; i++ { + for _, prompt := range prompts { + // Generate a random string of length completionReq.MaxTokens + completionStr := strings.Repeat("a", completionReq.MaxTokens) + if completionReq.Echo { + completionStr = prompt + completionStr + } + + res.Choices = append(res.Choices, openai.CompletionChoice{ + Text: completionStr, + Index: len(res.Choices), + }) + } + } + + inputTokens := 0 + for _, prompt := range prompts { + inputTokens += numTokens(prompt) + } + inputTokens *= n + completionTokens := completionReq.MaxTokens * len(prompts) * n + res.Usage = &openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, } + + // Serialize the response and send it back resBytes, _ = json.Marshal(res) fmt.Fprintln(w, string(resBytes)) } // getCompletionBody Returns the body of the request to create a completion. -func getCompletionBody(r *http.Request) (CompletionRequest, error) { - completion := CompletionRequest{} +func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) { + completion := openai.CompletionRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return CompletionRequest{}, err + return openai.CompletionRequest{}, err } err = json.Unmarshal(reqBody, &completion) if err != nil { - return CompletionRequest{}, err + return openai.CompletionRequest{}, err } return completion, nil } + +// TestCompletionWithO1Model Tests that O1 model is not supported for completion endpoint. +func TestCompletionWithO1Model(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: openai.O1, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O1 model, but returned: %v", err) + } +} + +// TestCompletionWithGPT4DotModels Tests that newer GPT4 models are not supported for completion endpoint. +func TestCompletionWithGPT4DotModels(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + models := []string{ + openai.GPT4Dot1, + openai.GPT4Dot120250414, + openai.GPT4Dot1Mini, + openai.GPT4Dot1Mini20250414, + openai.GPT4Dot1Nano, + openai.GPT4Dot1Nano20250414, + openai.GPT4Dot5Preview, + openai.GPT4Dot5Preview20250227, + } + + for _, model := range models { + t.Run(model, func(t *testing.T) { + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: model, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err) + } + }) + } +} + +// TestCompletionWithGPT4oModels Tests that GPT4o models are not supported for completion endpoint. +func TestCompletionWithGPT4oModels(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + models := []string{ + openai.GPT4o, + openai.GPT4o20240513, + openai.GPT4o20240806, + openai.GPT4o20241120, + openai.GPT4oLatest, + openai.GPT4oMini, + openai.GPT4oMini20240718, + } + + for _, model := range models { + t.Run(model, func(t *testing.T) { + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: model, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err) + } + }) + } +} + +// TestCompletionWithGPT5Models Tests that GPT5 models are not supported for completion endpoint. +func TestCompletionWithGPT5Models(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + models := []string{ + openai.GPT5, + openai.GPT5Mini, + openai.GPT5Nano, + openai.GPT5ChatLatest, + } + + for _, model := range models { + t.Run(model, func(t *testing.T) { + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: model, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err) + } + }) + } +} diff --git a/config.go b/config.go index c800df15c..4788ba62a 100644 --- a/config.go +++ b/config.go @@ -2,6 +2,7 @@ package openai import ( "net/http" + "regexp" ) const ( @@ -10,39 +11,50 @@ const ( azureAPIPrefix = "openai" azureDeploymentsPrefix = "deployments" + + AnthropicAPIVersion = "2023-06-01" ) type APIType string const ( - APITypeOpenAI APIType = "OPEN_AI" - APITypeAzure APIType = "AZURE" - APITypeAzureAD APIType = "AZURE_AD" + APITypeOpenAI APIType = "OPEN_AI" + APITypeAzure APIType = "AZURE" + APITypeAzureAD APIType = "AZURE_AD" + APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE" + APITypeAnthropic APIType = "ANTHROPIC" ) const AzureAPIKeyHeader = "api-key" +const defaultAssistantVersion = "v2" // upgrade to v2 to support vector store + +type HTTPDoer interface { + Do(req *http.Request) (*http.Response, error) +} + // ClientConfig is a configuration of a client. type ClientConfig struct { authToken string - BaseURL string - OrgID string - APIType APIType - APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD - Engine string // required when APIType is APITypeAzure or APITypeAzureAD - - HTTPClient *http.Client + BaseURL string + OrgID string + APIType APIType + APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic + AssistantVersion string + AzureModelMapperFunc func(model string) string // replace model to azure deployment name func + HTTPClient HTTPDoer EmptyMessagesLimit uint } func DefaultConfig(authToken string) ClientConfig { return ClientConfig{ - authToken: authToken, - BaseURL: openaiAPIURLv1, - APIType: APITypeOpenAI, - OrgID: "", + authToken: authToken, + BaseURL: openaiAPIURLv1, + APIType: APITypeOpenAI, + AssistantVersion: defaultAssistantVersion, + OrgID: "", HTTPClient: &http.Client{}, @@ -50,14 +62,33 @@ func DefaultConfig(authToken string) ClientConfig { } } -func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig { +func DefaultAzureConfig(apiKey, baseURL string) ClientConfig { return ClientConfig{ authToken: apiKey, BaseURL: baseURL, OrgID: "", APIType: APITypeAzure, - APIVersion: "2023-03-15-preview", - Engine: engine, + APIVersion: "2023-05-15", + AzureModelMapperFunc: func(model string) string { + return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") + }, + + HTTPClient: &http.Client{}, + + EmptyMessagesLimit: defaultEmptyMessagesLimit, + } +} + +func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig { + if baseURL == "" { + baseURL = "https://api.anthropic.com/v1" + } + return ClientConfig{ + authToken: apiKey, + BaseURL: baseURL, + OrgID: "", + APIType: APITypeAnthropic, + APIVersion: AnthropicAPIVersion, HTTPClient: &http.Client{}, @@ -68,3 +99,11 @@ func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig { func (ClientConfig) String() string { return "" } + +func (c ClientConfig) GetAzureDeploymentByModel(model string) string { + if c.AzureModelMapperFunc != nil { + return c.AzureModelMapperFunc(model) + } + + return model +} diff --git a/config_test.go b/config_test.go new file mode 100644 index 000000000..960230804 --- /dev/null +++ b/config_test.go @@ -0,0 +1,123 @@ +package openai_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai" +) + +func TestGetAzureDeploymentByModel(t *testing.T) { + cases := []struct { + Model string + AzureModelMapperFunc func(model string) string + Expect string + }{ + { + Model: "gpt-3.5-turbo", + Expect: "gpt-35-turbo", + }, + { + Model: "gpt-3.5-turbo-0301", + Expect: "gpt-35-turbo-0301", + }, + { + Model: "text-embedding-ada-002", + Expect: "text-embedding-ada-002", + }, + { + Model: "", + Expect: "", + }, + { + Model: "models", + Expect: "models", + }, + { + Model: "gpt-3.5-turbo", + Expect: "my-gpt35", + AzureModelMapperFunc: func(model string) string { + modelmapper := map[string]string{ + "gpt-3.5-turbo": "my-gpt35", + } + if val, ok := modelmapper[model]; ok { + return val + } + return model + }, + }, + } + + for _, c := range cases { + t.Run(c.Model, func(t *testing.T) { + conf := openai.DefaultAzureConfig("", "https://test.openai.azure.com/") + if c.AzureModelMapperFunc != nil { + conf.AzureModelMapperFunc = c.AzureModelMapperFunc + } + actual := conf.GetAzureDeploymentByModel(c.Model) + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + }) + } +} + +func TestDefaultAnthropicConfig(t *testing.T) { + apiKey := "test-key" + baseURL := "https://api.anthropic.com/v1" + + config := openai.DefaultAnthropicConfig(apiKey, baseURL) + + if config.APIType != openai.APITypeAnthropic { + t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType) + } + + if config.APIVersion != openai.AnthropicAPIVersion { + t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion) + } + + if config.BaseURL != baseURL { + t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL) + } + + if config.EmptyMessagesLimit != 300 { + t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit) + } +} + +func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) { + config := openai.DefaultAnthropicConfig("", "") + + if config.APIType != openai.APITypeAnthropic { + t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType) + } + + if config.APIVersion != openai.AnthropicAPIVersion { + t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion) + } + + expectedBaseURL := "https://api.anthropic.com/v1" + if config.BaseURL != expectedBaseURL { + t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL) + } +} + +func TestClientConfigString(t *testing.T) { + // String() should always return the constant value + conf := openai.DefaultConfig("dummy-token") + expected := "" + got := conf.String() + if got != expected { + t.Errorf("ClientConfig.String() = %q; want %q", got, expected) + } +} + +func TestGetAzureDeploymentByModel_NoMapper(t *testing.T) { + // On a zero-value or DefaultConfig, AzureModelMapperFunc is nil, + // so GetAzureDeploymentByModel should just return the input model. + conf := openai.DefaultConfig("dummy-token") + model := "some-model" + got := conf.GetAzureDeploymentByModel(model) + if got != model { + t.Errorf("GetAzureDeploymentByModel(%q) = %q; want %q", model, got, model) + } +} diff --git a/edits.go b/edits.go index 858a8e537..fe8ecd0c1 100644 --- a/edits.go +++ b/edits.go @@ -2,6 +2,7 @@ package openai import ( "context" + "fmt" "net/http" ) @@ -27,11 +28,22 @@ type EditsResponse struct { Created int64 `json:"created"` Usage Usage `json:"usage"` Choices []EditsChoice `json:"choices"` + + httpHeader } -// Perform an API call to the Edits endpoint. +// Edits Perform an API call to the Edits endpoint. +/* Deprecated: Users of the Edits API and its associated models (e.g., text-davinci-edit-001 or code-davinci-edit-001) +will need to migrate to GPT-3.5 Turbo by January 4, 2024. +You can use CreateChatCompletion or CreateChatCompletionStream instead. +*/ func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits"), request) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/edits", withModel(fmt.Sprint(request.Model))), + withBody(request), + ) if err != nil { return } diff --git a/edits_test.go b/edits_test.go index fa6c12825..d2a6db40d 100644 --- a/edits_test.go +++ b/edits_test.go @@ -1,10 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" @@ -12,26 +8,19 @@ import ( "net/http" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) // TestEdits Tests the edits endpoint of the API using the mocked server. func TestEdits(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/edits", handleEditEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - // create an edit request model := "ada" - editReq := EditsRequest{ + editReq := openai.EditsRequest{ Model: &model, Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" + @@ -40,7 +29,7 @@ func TestEdits(t *testing.T) { Instruction: "test instruction", N: 3, } - response, err := client.Edits(ctx, editReq) + response, err := client.Edits(context.Background(), editReq) checks.NoError(t, err, "Edits error") if len(response.Choices) != editReq.N { t.Fatalf("edits does not properly return the correct number of choices") @@ -56,14 +45,14 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var editReq EditsRequest + var editReq openai.EditsRequest editReq, err = getEditBody(r) if err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } // create a response - res := EditsResponse{ + res := openai.EditsResponse{ Object: "test-object", Created: time.Now().Unix(), } @@ -73,12 +62,12 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { completionTokens := int(float32(len(editString))/4) * editReq.N for i := 0; i < editReq.N; i++ { // instruction will be hidden and only seen by OpenAI - res.Choices = append(res.Choices, EditsChoice{ + res.Choices = append(res.Choices, openai.EditsChoice{ Text: editReq.Input + editString, Index: i, }) } - res.Usage = Usage{ + res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -88,16 +77,16 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { } // getEditBody Returns the body of the request to create an edit. -func getEditBody(r *http.Request) (EditsRequest, error) { - edit := EditsRequest{} +func getEditBody(r *http.Request) (openai.EditsRequest, error) { + edit := openai.EditsRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return EditsRequest{}, err + return openai.EditsRequest{}, err } err = json.Unmarshal(reqBody, &edit) if err != nil { - return EditsRequest{}, err + return openai.EditsRequest{}, err } return edit, nil } diff --git a/embeddings.go b/embeddings.go index 2deaccc3a..8593f8b5b 100644 --- a/embeddings.go +++ b/embeddings.go @@ -2,96 +2,43 @@ package openai import ( "context" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + "math" "net/http" ) +var ErrVectorLengthMismatch = errors.New("vector length mismatch") + // EmbeddingModel enumerates the models which can be used // to generate Embedding vectors. -type EmbeddingModel int - -// String implements the fmt.Stringer interface. -func (e EmbeddingModel) String() string { - return enumToString[e] -} - -// MarshalText implements the encoding.TextMarshaler interface. -func (e EmbeddingModel) MarshalText() ([]byte, error) { - return []byte(e.String()), nil -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -// On unrecognized value, it sets |e| to Unknown. -func (e *EmbeddingModel) UnmarshalText(b []byte) error { - if val, ok := stringToEnum[(string(b))]; ok { - *e = val - return nil - } - - *e = Unknown - - return nil -} +type EmbeddingModel string const ( - Unknown EmbeddingModel = iota - AdaSimilarity - BabbageSimilarity - CurieSimilarity - DavinciSimilarity - AdaSearchDocument - AdaSearchQuery - BabbageSearchDocument - BabbageSearchQuery - CurieSearchDocument - CurieSearchQuery - DavinciSearchDocument - DavinciSearchQuery - AdaCodeSearchCode - AdaCodeSearchText - BabbageCodeSearchCode - BabbageCodeSearchText - AdaEmbeddingV2 -) + // Deprecated: The following block is shut down. Use text-embedding-ada-002 instead. + AdaSimilarity EmbeddingModel = "text-similarity-ada-001" + BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001" + CurieSimilarity EmbeddingModel = "text-similarity-curie-001" + DavinciSimilarity EmbeddingModel = "text-similarity-davinci-001" + AdaSearchDocument EmbeddingModel = "text-search-ada-doc-001" + AdaSearchQuery EmbeddingModel = "text-search-ada-query-001" + BabbageSearchDocument EmbeddingModel = "text-search-babbage-doc-001" + BabbageSearchQuery EmbeddingModel = "text-search-babbage-query-001" + CurieSearchDocument EmbeddingModel = "text-search-curie-doc-001" + CurieSearchQuery EmbeddingModel = "text-search-curie-query-001" + DavinciSearchDocument EmbeddingModel = "text-search-davinci-doc-001" + DavinciSearchQuery EmbeddingModel = "text-search-davinci-query-001" + AdaCodeSearchCode EmbeddingModel = "code-search-ada-code-001" + AdaCodeSearchText EmbeddingModel = "code-search-ada-text-001" + BabbageCodeSearchCode EmbeddingModel = "code-search-babbage-code-001" + BabbageCodeSearchText EmbeddingModel = "code-search-babbage-text-001" -var enumToString = map[EmbeddingModel]string{ - AdaSimilarity: "text-similarity-ada-001", - BabbageSimilarity: "text-similarity-babbage-001", - CurieSimilarity: "text-similarity-curie-001", - DavinciSimilarity: "text-similarity-davinci-001", - AdaSearchDocument: "text-search-ada-doc-001", - AdaSearchQuery: "text-search-ada-query-001", - BabbageSearchDocument: "text-search-babbage-doc-001", - BabbageSearchQuery: "text-search-babbage-query-001", - CurieSearchDocument: "text-search-curie-doc-001", - CurieSearchQuery: "text-search-curie-query-001", - DavinciSearchDocument: "text-search-davinci-doc-001", - DavinciSearchQuery: "text-search-davinci-query-001", - AdaCodeSearchCode: "code-search-ada-code-001", - AdaCodeSearchText: "code-search-ada-text-001", - BabbageCodeSearchCode: "code-search-babbage-code-001", - BabbageCodeSearchText: "code-search-babbage-text-001", - AdaEmbeddingV2: "text-embedding-ada-002", -} - -var stringToEnum = map[string]EmbeddingModel{ - "text-similarity-ada-001": AdaSimilarity, - "text-similarity-babbage-001": BabbageSimilarity, - "text-similarity-curie-001": CurieSimilarity, - "text-similarity-davinci-001": DavinciSimilarity, - "text-search-ada-doc-001": AdaSearchDocument, - "text-search-ada-query-001": AdaSearchQuery, - "text-search-babbage-doc-001": BabbageSearchDocument, - "text-search-babbage-query-001": BabbageSearchQuery, - "text-search-curie-doc-001": CurieSearchDocument, - "text-search-curie-query-001": CurieSearchQuery, - "text-search-davinci-doc-001": DavinciSearchDocument, - "text-search-davinci-query-001": DavinciSearchQuery, - "code-search-ada-code-001": AdaCodeSearchCode, - "code-search-ada-text-001": AdaCodeSearchText, - "code-search-babbage-code-001": BabbageCodeSearchCode, - "code-search-babbage-text-001": BabbageCodeSearchText, - "text-embedding-ada-002": AdaEmbeddingV2, -} + AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002" + SmallEmbedding3 EmbeddingModel = "text-embedding-3-small" + LargeEmbedding3 EmbeddingModel = "text-embedding-3-large" +) // Embedding is a special format of data representation that can be easily utilized by machine // learning models and algorithms. The embedding is an information dense representation of the @@ -105,18 +52,128 @@ type Embedding struct { Index int `json:"index"` } +// DotProduct calculates the dot product of the embedding vector with another +// embedding vector. Both vectors must have the same length; otherwise, an +// ErrVectorLengthMismatch is returned. The method returns the calculated dot +// product as a float32 value. +func (e *Embedding) DotProduct(other *Embedding) (float32, error) { + if len(e.Embedding) != len(other.Embedding) { + return 0, ErrVectorLengthMismatch + } + + var dotProduct float32 + for i := range e.Embedding { + dotProduct += e.Embedding[i] * other.Embedding[i] + } + + return dotProduct, nil +} + // EmbeddingResponse is the response from a Create embeddings request. type EmbeddingResponse struct { Object string `json:"object"` Data []Embedding `json:"data"` Model EmbeddingModel `json:"model"` Usage Usage `json:"usage"` + + httpHeader +} + +type base64String string + +func (b base64String) Decode() ([]float32, error) { + decodedData, err := base64.StdEncoding.DecodeString(string(b)) + if err != nil { + return nil, err + } + + const sizeOfFloat32 = 4 + floats := make([]float32, len(decodedData)/sizeOfFloat32) + for i := 0; i < len(floats); i++ { + floats[i] = math.Float32frombits(binary.LittleEndian.Uint32(decodedData[i*4 : (i+1)*4])) + } + + return floats, nil +} + +// Base64Embedding is a container for base64 encoded embeddings. +type Base64Embedding struct { + Object string `json:"object"` + Embedding base64String `json:"embedding"` + Index int `json:"index"` +} + +// EmbeddingResponseBase64 is the response from a Create embeddings request with base64 encoding format. +type EmbeddingResponseBase64 struct { + Object string `json:"object"` + Data []Base64Embedding `json:"data"` + Model EmbeddingModel `json:"model"` + Usage Usage `json:"usage"` + + httpHeader +} + +// ToEmbeddingResponse converts an embeddingResponseBase64 to an EmbeddingResponse. +func (r *EmbeddingResponseBase64) ToEmbeddingResponse() (EmbeddingResponse, error) { + data := make([]Embedding, len(r.Data)) + + for i, base64Embedding := range r.Data { + embedding, err := base64Embedding.Embedding.Decode() + if err != nil { + return EmbeddingResponse{}, err + } + + data[i] = Embedding{ + Object: base64Embedding.Object, + Embedding: embedding, + Index: base64Embedding.Index, + } + } + + return EmbeddingResponse{ + Object: r.Object, + Model: r.Model, + Data: data, + Usage: r.Usage, + }, nil } -// EmbeddingRequest is the input to a Create embeddings request. +type EmbeddingRequestConverter interface { + // Needs to be of type EmbeddingRequestStrings or EmbeddingRequestTokens + Convert() EmbeddingRequest +} + +// EmbeddingEncodingFormat is the format of the embeddings data. +// Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. +// If not specified OpenAI will use "float". +type EmbeddingEncodingFormat string + +const ( + EmbeddingEncodingFormatFloat EmbeddingEncodingFormat = "float" + EmbeddingEncodingFormatBase64 EmbeddingEncodingFormat = "base64" +) + type EmbeddingRequest struct { + Input any `json:"input"` + Model EmbeddingModel `json:"model"` + User string `json:"user,omitempty"` + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` + // Dimensions The number of dimensions the resulting output embeddings should have. + // Only supported in text-embedding-3 and later models. + Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` +} + +func (r EmbeddingRequest) Convert() EmbeddingRequest { + return r +} + +// EmbeddingRequestStrings is the input to a create embeddings request with a slice of strings. +type EmbeddingRequestStrings struct { // Input is a slice of strings for which you want to generate an Embedding vector. - // Each input must not exceed 2048 tokens in length. + // Each input must not exceed 8192 tokens in length. // OpenAPI suggests replacing newlines (\n) in your input with a single space, as they // have observed inferior results when newlines are present. // E.g. @@ -127,17 +184,114 @@ type EmbeddingRequest struct { Model EmbeddingModel `json:"model"` // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. User string `json:"user"` + // EmbeddingEncodingFormat is the format of the embeddings data. + // Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. + // If not specified OpenAI will use "float". + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` + // Dimensions The number of dimensions the resulting output embeddings should have. + // Only supported in text-embedding-3 and later models. + Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` +} + +func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { + return EmbeddingRequest{ + Input: r.Input, + Model: r.Model, + User: r.User, + EncodingFormat: r.EncodingFormat, + Dimensions: r.Dimensions, + ExtraBody: r.ExtraBody, + } +} + +type EmbeddingRequestTokens struct { + // Input is a slice of slices of ints ([][]int) for which you want to generate an Embedding vector. + // Each input must not exceed 8192 tokens in length. + // OpenAPI suggests replacing newlines (\n) in your input with a single space, as they + // have observed inferior results when newlines are present. + // E.g. + // "The food was delicious and the waiter..." + Input [][]int `json:"input"` + // ID of the model to use. You can use the List models API to see all of your available models, + // or see our Model overview for descriptions of them. + Model EmbeddingModel `json:"model"` + // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. + User string `json:"user"` + // EmbeddingEncodingFormat is the format of the embeddings data. + // Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. + // If not specified OpenAI will use "float". + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` + // Dimensions The number of dimensions the resulting output embeddings should have. + // Only supported in text-embedding-3 and later models. + Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` +} + +func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { + return EmbeddingRequest{ + Input: r.Input, + Model: r.Model, + User: r.User, + EncodingFormat: r.EncodingFormat, + Dimensions: r.Dimensions, + ExtraBody: r.ExtraBody, + } } -// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. +// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |body.Input|. // https://beta.openai.com/docs/api-reference/embeddings/create -func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request) +// +// Body should be of type EmbeddingRequestStrings for embedding strings or EmbeddingRequestTokens +// for embedding groups of text already converted to tokens. +func (c *Client) CreateEmbeddings( + ctx context.Context, + conv EmbeddingRequestConverter, +) (res EmbeddingResponse, err error) { + baseReq := conv.Convert() + + // The body map is used to dynamically construct the request payload for the embedding API. + // Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields + // based on their presence, avoiding unnecessary or empty fields in the request. + extraBody := baseReq.ExtraBody + baseReq.ExtraBody = nil + + // Serialize baseReq to JSON + jsonData, err := json.Marshal(baseReq) + if err != nil { + return + } + + // Deserialize JSON to map[string]any + var body map[string]any + _ = json.Unmarshal(jsonData, &body) + + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/embeddings", withModel(string(baseReq.Model))), + withBody(body), // Main request body. + withExtraBody(extraBody), // Merge ExtraBody fields. + ) if err != nil { return } - err = c.sendRequest(req, &resp) + if baseReq.EncodingFormat != EmbeddingEncodingFormatBase64 { + err = c.sendRequest(req, &res) + return + } + + base64Response := &EmbeddingResponseBase64{} + err = c.sendRequest(req, base64Response) + if err != nil { + return + } + res, err = base64Response.ToEmbeddingResponse() return } diff --git a/embeddings_test.go b/embeddings_test.go index 252f7a5a0..07f1262cb 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -1,39 +1,42 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" - "bytes" "context" "encoding/json" + "errors" "fmt" + "math" "net/http" + "reflect" "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestEmbedding(t *testing.T) { - embeddedModels := []EmbeddingModel{ - AdaSimilarity, - BabbageSimilarity, - CurieSimilarity, - DavinciSimilarity, - AdaSearchDocument, - AdaSearchQuery, - BabbageSearchDocument, - BabbageSearchQuery, - CurieSearchDocument, - CurieSearchQuery, - DavinciSearchDocument, - DavinciSearchQuery, - AdaCodeSearchCode, - AdaCodeSearchText, - BabbageCodeSearchCode, - BabbageCodeSearchText, + embeddedModels := []openai.EmbeddingModel{ + openai.AdaSimilarity, + openai.BabbageSimilarity, + openai.CurieSimilarity, + openai.DavinciSimilarity, + openai.AdaSearchDocument, + openai.AdaSearchQuery, + openai.BabbageSearchDocument, + openai.BabbageSearchQuery, + openai.CurieSearchDocument, + openai.CurieSearchQuery, + openai.DavinciSearchDocument, + openai.DavinciSearchQuery, + openai.AdaCodeSearchCode, + openai.AdaCodeSearchText, + openai.BabbageCodeSearchCode, + openai.BabbageCodeSearchText, } for _, model := range embeddedModels { - embeddingReq := EmbeddingRequest{ + // test embedding request with strings (simple embedding request) + embeddingReq := openai.EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", @@ -44,48 +47,281 @@ func TestEmbedding(t *testing.T) { // the AdaSearchQuery type marshaled, err := json.Marshal(embeddingReq) checks.NoError(t, err, "Could not marshal embedding request") - if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { + t.Fatalf("Expected embedding request to contain model field") + } + + // test embedding request with strings and extra_body param + embeddingReqWithExtraBody := openai.EmbeddingRequest{ + Input: []string{ + "The food was delicious and the waiter", + "Other examples of embedding request", + }, + Model: model, + ExtraBody: map[string]any{ + "input_type": "query", + "truncate": "NONE", + }, + } + marshaled, err = json.Marshal(embeddingReqWithExtraBody) + checks.NoError(t, err, "Could not marshal embedding request") + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { + t.Fatalf("Expected embedding request to contain model field") + } + + // test embedding request with strings + embeddingReqStrings := openai.EmbeddingRequestStrings{ + Input: []string{ + "The food was delicious and the waiter", + "Other examples of embedding request", + }, + Model: model, + } + marshaled, err = json.Marshal(embeddingReqStrings) + checks.NoError(t, err, "Could not marshal embedding request") + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { + t.Fatalf("Expected embedding request to contain model field") + } + + // test embedding request with tokens + embeddingReqTokens := openai.EmbeddingRequestTokens{ + Input: [][]int{ + {464, 2057, 373, 12625, 290, 262, 46612}, + {6395, 6096, 286, 11525, 12083, 2581}, + }, + Model: model, + } + marshaled, err = json.Marshal(embeddingReqTokens) + checks.NoError(t, err, "Could not marshal embedding request") + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { t.Fatalf("Expected embedding request to contain model field") } } } -func TestEmbeddingModel(t *testing.T) { - var em EmbeddingModel - err := em.UnmarshalText([]byte("text-similarity-ada-001")) - checks.NoError(t, err, "Could not marshal embedding model") +func TestEmbeddingEndpoint(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() - if em != AdaSimilarity { - t.Errorf("Model is not equal to AdaSimilarity") + sampleEmbeddings := []openai.Embedding{ + {Embedding: []float32{1.23, 4.56, 7.89}}, + {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, } - err = em.UnmarshalText([]byte("some-non-existent-model")) - checks.NoError(t, err, "Could not marshal embedding model") - if em != Unknown { - t.Errorf("Model is not equal to Unknown") + sampleBase64Embeddings := []openai.Base64Embedding{ + {Embedding: "pHCdP4XrkUDhevxA"}, + {Embedding: "/1jku0G/rLvA/EI8"}, } -} -func TestEmbeddingEndpoint(t *testing.T) { - server := test.NewTestServer() server.RegisterHandler( "/v1/embeddings", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(EmbeddingResponse{}) + var req struct { + EncodingFormat openai.EmbeddingEncodingFormat `json:"encoding_format"` + User string `json:"user"` + } + _ = json.NewDecoder(r.Body).Decode(&req) + + var resBytes []byte + switch { + case req.User == "invalid": + w.WriteHeader(http.StatusBadRequest) + return + case req.EncodingFormat == openai.EmbeddingEncodingFormatBase64: + resBytes, _ = json.Marshal(openai.EmbeddingResponseBase64{Data: sampleBase64Embeddings}) + default: + resBytes, _ = json.Marshal(openai.EmbeddingResponse{Data: sampleEmbeddings}) + } + fmt.Fprintln(w, string(resBytes)) + }, + ) + // test create embeddings with strings (simple embedding request) + res, err := client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{}) + checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + + // test create embeddings with strings (ExtraBody in request) + res, err = client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + ExtraBody: map[string]any{ + "input_type": "query", + "truncate": "NONE", + }, + Dimensions: 1, + }, + ) + checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + + // test create embeddings with strings (ExtraBody in request and ) + _, err = client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + Input: make(chan int), // Channels are not serializable + Model: "example_model", + }, + ) + checks.HasError(t, err, "CreateEmbeddings error") + + // test failed (Serialize JSON error) + res, err = client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + EncodingFormat: openai.EmbeddingEncodingFormatBase64, + }, + ) + checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + + // test create embeddings with strings + res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestStrings{}) + checks.NoError(t, err, "CreateEmbeddings strings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + + // test create embeddings with tokens + res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestTokens{}) + checks.NoError(t, err, "CreateEmbeddings tokens error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + + // test failed sendRequest + _, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ + User: "invalid", + EncodingFormat: openai.EmbeddingEncodingFormatBase64, + }) + checks.HasError(t, err, "CreateEmbeddings error") +} + +func TestAzureEmbeddingEndpoint(t *testing.T) { + client, server, teardown := setupAzureTestServer() + defer teardown() + + sampleEmbeddings := []openai.Embedding{ + {Embedding: []float32{1.23, 4.56, 7.89}}, + {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, + } + + server.RegisterHandler( + "/openai/deployments/text-embedding-ada-002/embeddings", + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.EmbeddingResponse{Data: sampleEmbeddings}) fmt.Fprintln(w, string(resBytes)) }, ) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) + // test create embeddings with strings (simple embedding request) + res, err := client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ + Model: openai.AdaEmbeddingV2, + }) checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } +} + +func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { + type fields struct { + Object string + Data []openai.Base64Embedding + Model openai.EmbeddingModel + Usage openai.Usage + } + tests := []struct { + name string + fields fields + want openai.EmbeddingResponse + wantErr bool + }{ + { + name: "test embedding response base64 to embedding response", + fields: fields{ + Data: []openai.Base64Embedding{ + {Embedding: "pHCdP4XrkUDhevxA"}, + {Embedding: "/1jku0G/rLvA/EI8"}, + }, + }, + want: openai.EmbeddingResponse{ + Data: []openai.Embedding{ + {Embedding: []float32{1.23, 4.56, 7.89}}, + {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, + }, + }, + wantErr: false, + }, + { + name: "Invalid embedding", + fields: fields{ + Data: []openai.Base64Embedding{ + { + Embedding: "----", + }, + }, + }, + want: openai.EmbeddingResponse{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.EmbeddingResponseBase64{ + Object: tt.fields.Object, + Data: tt.fields.Data, + Model: tt.fields.Model, + Usage: tt.fields.Usage, + } + got, err := r.ToEmbeddingResponse() + if (err != nil) != tt.wantErr { + t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDotProduct(t *testing.T) { + v1 := &openai.Embedding{Embedding: []float32{1, 2, 3}} + v2 := &openai.Embedding{Embedding: []float32{2, 4, 6}} + expected := float32(28.0) + + result, err := v1.DotProduct(v2) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if math.Abs(float64(result-expected)) > 1e-12 { + t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) + } + + v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}} + v2 = &openai.Embedding{Embedding: []float32{0, 1, 0}} + expected = float32(0.0) + + result, err = v1.DotProduct(v2) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if math.Abs(float64(result-expected)) > 1e-12 { + t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) + } + + // Test for VectorLengthMismatchError + v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}} + v2 = &openai.Embedding{Embedding: []float32{0, 1}} + _, err = v1.DotProduct(v2) + if !errors.Is(err, openai.ErrVectorLengthMismatch) { + t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err) + } } diff --git a/engines.go b/engines.go index bb6a66ce4..5a0dba858 100644 --- a/engines.go +++ b/engines.go @@ -12,17 +12,21 @@ type Engine struct { Object string `json:"object"` Owner string `json:"owner"` Ready bool `json:"ready"` + + httpHeader } // EnginesList is a list of engines. type EnginesList struct { Engines []Engine `json:"data"` + + httpHeader } // ListEngines Lists the currently available engines, and provides basic // information about each option such as the owner and availability. func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/engines"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/engines")) if err != nil { return } @@ -38,7 +42,7 @@ func (c *Client) GetEngine( engineID string, ) (engine Engine, err error) { urlSuffix := fmt.Sprintf("/engines/%s", engineID) - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } diff --git a/engines_test.go b/engines_test.go new file mode 100644 index 000000000..d26aa5541 --- /dev/null +++ b/engines_test.go @@ -0,0 +1,47 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +// TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server. +func TestGetEngine(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.Engine{}) + fmt.Fprintln(w, string(resBytes)) + }) + _, err := client.GetEngine(context.Background(), "text-davinci-003") + checks.NoError(t, err, "GetEngine error") +} + +// TestListEngines Tests the list engines endpoint of the API using the mocked server. +func TestListEngines(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.EnginesList{}) + fmt.Fprintln(w, string(resBytes)) + }) + _, err := client.ListEngines(context.Background()) + checks.NoError(t, err, "ListEngines error") +} + +func TestListEnginesReturnError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusTeapot) + }) + + _, err := client.ListEngines(context.Background()) + checks.HasError(t, err, "ListEngines did not fail") +} diff --git a/error.go b/error.go index 6354f43b5..8a74bd52c 100644 --- a/error.go +++ b/error.go @@ -3,21 +3,33 @@ package openai import ( "encoding/json" "fmt" + "strings" ) // APIError provides error information returned by the OpenAI API. +// InnerError struct is only valid for Azure OpenAI Service. type APIError struct { - Code any `json:"code,omitempty"` - Message string `json:"message"` - Param *string `json:"param,omitempty"` - Type string `json:"type"` - HTTPStatusCode int `json:"-"` + Code any `json:"code,omitempty"` + Message string `json:"message"` + Param *string `json:"param,omitempty"` + Type string `json:"type"` + HTTPStatus string `json:"-"` + HTTPStatusCode int `json:"-"` + InnerError *InnerError `json:"innererror,omitempty"` } -// RequestError provides informations about generic request errors. +// InnerError Azure Content filtering. Only valid for Azure OpenAI Service. +type InnerError struct { + Code string `json:"code,omitempty"` + ContentFilterResults ContentFilterResults `json:"content_filter_result,omitempty"` +} + +// RequestError provides information about generic request errors. type RequestError struct { + HTTPStatus string HTTPStatusCode int Err error + Body []byte } type ErrorResponse struct { @@ -26,7 +38,7 @@ type ErrorResponse struct { func (e *APIError) Error() string { if e.HTTPStatusCode > 0 { - return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Message) + return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Message) } return e.Message @@ -41,12 +53,30 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { err = json.Unmarshal(rawMap["message"], &e.Message) if err != nil { - return + // If the parameter field of a function call is invalid as a JSON schema + // refs: https://github.com/sashabaranov/go-openai/issues/381 + var messages []string + err = json.Unmarshal(rawMap["message"], &messages) + if err != nil { + return + } + e.Message = strings.Join(messages, ", ") } - err = json.Unmarshal(rawMap["type"], &e.Type) - if err != nil { - return + // optional fields for azure openai + // refs: https://github.com/sashabaranov/go-openai/issues/343 + if _, ok := rawMap["type"]; ok { + err = json.Unmarshal(rawMap["type"], &e.Type) + if err != nil { + return + } + } + + if _, ok := rawMap["innererror"]; ok { + err = json.Unmarshal(rawMap["innererror"], &e.InnerError) + if err != nil { + return + } } // optional fields @@ -74,7 +104,10 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { } func (e *RequestError) Error() string { - return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Err) + return fmt.Sprintf( + "error, status code: %d, status: %s, message: %s, body: %s", + e.HTTPStatusCode, e.HTTPStatus, e.Err, e.Body, + ) } func (e *RequestError) Unwrap() error { diff --git a/error_accumulator.go b/error_accumulator.go deleted file mode 100644 index ca6cec6e3..000000000 --- a/error_accumulator.go +++ /dev/null @@ -1,51 +0,0 @@ -package openai - -import ( - "bytes" - "fmt" - "io" -) - -type errorAccumulator interface { - write(p []byte) error - unmarshalError() *ErrorResponse -} - -type errorBuffer interface { - io.Writer - Len() int - Bytes() []byte -} - -type defaultErrorAccumulator struct { - buffer errorBuffer - unmarshaler unmarshaler -} - -func newErrorAccumulator() errorAccumulator { - return &defaultErrorAccumulator{ - buffer: &bytes.Buffer{}, - unmarshaler: &jsonUnmarshaler{}, - } -} - -func (e *defaultErrorAccumulator) write(p []byte) error { - _, err := e.buffer.Write(p) - if err != nil { - return fmt.Errorf("error accumulator write error, %w", err) - } - return nil -} - -func (e *defaultErrorAccumulator) unmarshalError() (errResp *ErrorResponse) { - if e.buffer.Len() == 0 { - return - } - - err := e.unmarshaler.unmarshal(e.buffer.Bytes(), &errResp) - if err != nil { - errResp = nil - } - - return -} diff --git a/error_accumulator_test.go b/error_accumulator_test.go deleted file mode 100644 index ecf954d58..000000000 --- a/error_accumulator_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package openai //nolint:testpackage // testing private field - -import ( - "bytes" - "context" - "errors" - "net/http" - "testing" - - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" -) - -var ( - errTestUnmarshalerFailed = errors.New("test unmarshaler failed") - errTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed") -) - -type ( - failingUnMarshaller struct{} - failingErrorBuffer struct{} -) - -func (b *failingErrorBuffer) Write(_ []byte) (n int, err error) { - return 0, errTestErrorAccumulatorWriteFailed -} - -func (b *failingErrorBuffer) Len() int { - return 0 -} - -func (b *failingErrorBuffer) Bytes() []byte { - return []byte{} -} - -func (*failingUnMarshaller) unmarshal(_ []byte, _ any) error { - return errTestUnmarshalerFailed -} - -func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) { - accumulator := &defaultErrorAccumulator{ - buffer: &bytes.Buffer{}, - unmarshaler: &failingUnMarshaller{}, - } - - respErr := accumulator.unmarshalError() - if respErr != nil { - t.Fatalf("Did not return nil with empty buffer: %v", respErr) - } - - err := accumulator.write([]byte("{")) - if err != nil { - t.Fatalf("%+v", err) - } - - respErr = accumulator.unmarshalError() - if respErr != nil { - t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr) - } -} - -func TestErrorByteWriteErrors(t *testing.T) { - accumulator := &defaultErrorAccumulator{ - buffer: &failingErrorBuffer{}, - unmarshaler: &jsonUnmarshaler{}, - } - err := accumulator.write([]byte("{")) - if !errors.Is(err, errTestErrorAccumulatorWriteFailed) { - t.Fatalf("Did not return error when write failed: %v", err) - } -} - -func TestErrorAccumulatorWriteErrors(t *testing.T) { - var err error - server := test.NewTestServer() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "error", 200) - }) - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - - ctx := context.Background() - - stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) - checks.NoError(t, err) - - stream.errAccumulator = &defaultErrorAccumulator{ - buffer: &failingErrorBuffer{}, - unmarshaler: &jsonUnmarshaler{}, - } - - _, err = stream.Recv() - checks.ErrorIs(t, err, errTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error()) -} diff --git a/error_test.go b/error_test.go new file mode 100644 index 000000000..48cbe4f29 --- /dev/null +++ b/error_test.go @@ -0,0 +1,279 @@ +package openai_test + +import ( + "errors" + "net/http" + "reflect" + "testing" + + "github.com/sashabaranov/go-openai" +) + +func TestAPIErrorUnmarshalJSON(t *testing.T) { + type testCase struct { + name string + response string + hasError bool + checkFunc func(t *testing.T, apiErr openai.APIError) + } + testCases := []testCase{ + // testcase for message field + { + name: "parse succeeds when the message is string", + response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorMessage(t, apiErr, "foo") + }, + }, + { + name: "parse succeeds when the message is array with single item", + response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorMessage(t, apiErr, "foo") + }, + }, + { + name: "parse succeeds when the message is array with multiple items", + response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorMessage(t, apiErr, "foo, bar, baz") + }, + }, + { + name: "parse succeeds when the message is empty array", + response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorMessage(t, apiErr, "") + }, + }, + { + name: "parse succeeds when the message is null", + response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorMessage(t, apiErr, "") + }, + }, + { + name: "parse succeeds when the innerError is not exists (Azure Openai)", + response: `{ + "message": "test message", + "type": null, + "param": "prompt", + "code": "content_filter", + "status": 400, + "innererror": { + "code": "ResponsibleAIPolicyViolation", + "content_filter_result": { + "hate": { + "filtered": false, + "severity": "safe" + }, + "self_harm": { + "filtered": false, + "severity": "safe" + }, + "sexual": { + "filtered": true, + "severity": "medium" + }, + "violence": { + "filtered": false, + "severity": "safe" + } + } + } + }`, + hasError: false, + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorInnerError(t, apiErr, &openai.InnerError{ + Code: "ResponsibleAIPolicyViolation", + ContentFilterResults: openai.ContentFilterResults{ + Hate: openai.Hate{ + Filtered: false, + Severity: "safe", + }, + SelfHarm: openai.SelfHarm{ + Filtered: false, + Severity: "safe", + }, + Sexual: openai.Sexual{ + Filtered: true, + Severity: "medium", + }, + Violence: openai.Violence{ + Filtered: false, + Severity: "safe", + }, + }, + }) + }, + }, + { + name: "parse succeeds when the innerError is empty (Azure Openai)", + response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": {}}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorInnerError(t, apiErr, &openai.InnerError{}) + }, + }, + { + name: "parse succeeds when the innerError is not InnerError struct (Azure Openai)", + response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": "test"}`, + hasError: true, + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorInnerError(t, apiErr, &openai.InnerError{}) + }, + }, + { + name: "parse failed when the message is object", + response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is int", + response: `{"message":1,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is float", + response: `{"message":0.1,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is bool", + response: `{"message":true,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is not exists", + response: `{"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + // testcase for code field + { + name: "parse succeeds when the code is int", + response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorCode(t, apiErr, 418) + }, + }, + { + name: "parse succeeds when the code is string", + response: `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorCode(t, apiErr, "teapot") + }, + }, + { + name: "parse succeeds when the code is not exists", + response: `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorCode(t, apiErr, nil) + }, + }, + // testcase for param field + { + name: "parse failed when the param is bool", + response: `{"code":418,"message":"I'm a teapot","param":true,"type":"teapot_error"}`, + hasError: true, + }, + // testcase for type field + { + name: "parse failed when the type is bool", + response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":true}`, + hasError: true, + }, + // testcase for error response + { + name: "parse failed when the response is invalid json", + response: `--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, + hasError: true, + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorCode(t, apiErr, nil) + assertAPIErrorMessage(t, apiErr, "") + assertAPIErrorParam(t, apiErr, nil) + assertAPIErrorType(t, apiErr, "") + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var apiErr openai.APIError + err := apiErr.UnmarshalJSON([]byte(tc.response)) + if (err != nil) != tc.hasError { + t.Errorf("Unexpected error: %v", err) + } + if tc.checkFunc != nil { + tc.checkFunc(t, apiErr) + } + }) + } +} + +func assertAPIErrorMessage(t *testing.T, apiErr openai.APIError, expected string) { + if apiErr.Message != expected { + t.Errorf("Unexpected APIError message: %v; expected: %s", apiErr, expected) + } +} + +func assertAPIErrorInnerError(t *testing.T, apiErr openai.APIError, expected interface{}) { + if !reflect.DeepEqual(apiErr.InnerError, expected) { + t.Errorf("Unexpected APIError InnerError: %v; expected: %v; ", apiErr, expected) + } +} + +func assertAPIErrorCode(t *testing.T, apiErr openai.APIError, expected interface{}) { + switch v := apiErr.Code.(type) { + case int: + if v != expected { + t.Errorf("Unexpected APIError code integer: %d; expected %d", v, expected) + } + case string: + if v != expected { + t.Errorf("Unexpected APIError code string: %s; expected %s", v, expected) + } + case nil: + default: + t.Errorf("Unexpected APIError error code type: %T", v) + } +} + +func assertAPIErrorParam(t *testing.T, apiErr openai.APIError, expected *string) { + if apiErr.Param != expected { + t.Errorf("Unexpected APIError param: %v; expected: %s", apiErr, *expected) + } +} + +func assertAPIErrorType(t *testing.T, apiErr openai.APIError, typ string) { + if apiErr.Type != typ { + t.Errorf("Unexpected API type: %v; expected: %s", apiErr, typ) + } +} + +func TestRequestError(t *testing.T) { + var err error = &openai.RequestError{ + HTTPStatusCode: http.StatusTeapot, + Err: errors.New("i am a teapot"), + } + + var reqErr *openai.RequestError + if !errors.As(err, &reqErr) { + t.Fatalf("Error is not a RequestError: %+v", err) + } + + if reqErr.HTTPStatusCode != 418 { + t.Fatalf("Unexpected request error status code: %d", reqErr.HTTPStatusCode) + } + + if reqErr.Unwrap() == nil { + t.Fatalf("Empty request error occurred") + } +} diff --git a/example_test.go b/example_test.go index da253806d..5910ffb84 100644 --- a/example_test.go +++ b/example_test.go @@ -28,7 +28,6 @@ func Example() { }, }, ) - if err != nil { fmt.Printf("ChatCompletion error: %v\n", err) return @@ -60,7 +59,7 @@ func ExampleClient_CreateChatCompletionStream() { } defer stream.Close() - fmt.Printf("Stream response: ") + fmt.Print("Stream response: ") for { var response openai.ChatCompletionStreamResponse response, err = stream.Recv() @@ -74,7 +73,7 @@ func ExampleClient_CreateChatCompletionStream() { return } - fmt.Printf(response.Choices[0].Delta.Content) + fmt.Println(response.Choices[0].Delta.Content) } } @@ -83,7 +82,7 @@ func ExampleClient_CreateCompletion() { resp, err := client.CreateCompletion( context.Background(), openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", }, @@ -100,7 +99,7 @@ func ExampleClient_CreateCompletionStream() { stream, err := client.CreateCompletionStream( context.Background(), openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", Stream: true, @@ -305,8 +304,7 @@ func Example_chatbot() { func ExampleDefaultAzureConfig() { azureKey := os.Getenv("AZURE_OPENAI_API_KEY") // Your azure API key azureEndpoint := os.Getenv("AZURE_OPENAI_ENDPOINT") // Your azure OpenAI endpoint - azureModel := os.Getenv("AZURE_OPENAI_MODEL") // Your model deployment name - config := openai.DefaultAzureConfig(azureKey, azureEndpoint, azureModel) + config := openai.DefaultAzureConfig(azureKey, azureEndpoint) client := openai.NewClientWithConfig(config) resp, err := client.CreateChatCompletion( context.Background(), @@ -320,7 +318,6 @@ func ExampleDefaultAzureConfig() { }, }, ) - if err != nil { fmt.Printf("ChatCompletion error: %v\n", err) return diff --git a/examples/completion-with-tool/main.go b/examples/completion-with-tool/main.go new file mode 100644 index 000000000..26126e41b --- /dev/null +++ b/examples/completion-with-tool/main.go @@ -0,0 +1,94 @@ +package main + +import ( + "context" + "fmt" + "os" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" +) + +func main() { + ctx := context.Background() + client := openai.NewClient(os.Getenv("OPENAI_API_KEY")) + + // describe the function & its inputs + params := jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "location": { + Type: jsonschema.String, + Description: "The city and state, e.g. San Francisco, CA", + }, + "unit": { + Type: jsonschema.String, + Enum: []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + } + f := openai.FunctionDefinition{ + Name: "get_current_weather", + Description: "Get the current weather in a given location", + Parameters: params, + } + t := openai.Tool{ + Type: openai.ToolTypeFunction, + Function: &f, + } + + // simulate user asking a question that requires the function + dialogue := []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleUser, Content: "What is the weather in Boston today?"}, + } + fmt.Printf("Asking OpenAI '%v' and providing it a '%v()' function...\n", + dialogue[0].Content, f.Name) + resp, err := client.CreateChatCompletion(ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4TurboPreview, + Messages: dialogue, + Tools: []openai.Tool{t}, + }, + ) + if err != nil || len(resp.Choices) != 1 { + fmt.Printf("Completion error: err:%v len(choices):%v\n", err, + len(resp.Choices)) + return + } + msg := resp.Choices[0].Message + if len(msg.ToolCalls) != 1 { + fmt.Printf("Completion error: len(toolcalls): %v\n", len(msg.ToolCalls)) + return + } + + // simulate calling the function & responding to OpenAI + dialogue = append(dialogue, msg) + fmt.Printf("OpenAI called us back wanting to invoke our function '%v' with params '%v'\n", + msg.ToolCalls[0].Function.Name, msg.ToolCalls[0].Function.Arguments) + dialogue = append(dialogue, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleTool, + Content: "Sunny and 80 degrees.", + Name: msg.ToolCalls[0].Function.Name, + ToolCallID: msg.ToolCalls[0].ID, + }) + fmt.Printf("Sending OpenAI our '%v()' function's response and requesting the reply to the original question...\n", + f.Name) + resp, err = client.CreateChatCompletion(ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4TurboPreview, + Messages: dialogue, + Tools: []openai.Tool{t}, + }, + ) + if err != nil || len(resp.Choices) != 1 { + fmt.Printf("2nd completion error: err:%v len(choices):%v\n", err, + len(resp.Choices)) + return + } + + // display OpenAI's response to the original question utilizing our function + msg = resp.Choices[0].Message + fmt.Printf("OpenAI answered the original request with: %v\n", + msg.Content) +} diff --git a/examples/completion/main.go b/examples/completion/main.go index 22af1fd82..8c5cbd5ca 100644 --- a/examples/completion/main.go +++ b/examples/completion/main.go @@ -13,7 +13,7 @@ func main() { resp, err := client.CreateCompletion( context.Background(), openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", }, diff --git a/files.go b/files.go index b701b9454..edc9f2a20 100644 --- a/files.go +++ b/files.go @@ -14,20 +14,77 @@ type FileRequest struct { Purpose string `json:"purpose"` } +// PurposeType represents the purpose of the file when uploading. +type PurposeType string + +const ( + PurposeFineTune PurposeType = "fine-tune" + PurposeFineTuneResults PurposeType = "fine-tune-results" + PurposeAssistants PurposeType = "assistants" + PurposeAssistantsOutput PurposeType = "assistants_output" + PurposeBatch PurposeType = "batch" +) + +// FileBytesRequest represents a file upload request. +type FileBytesRequest struct { + // the name of the uploaded file in OpenAI + Name string + // the bytes of the file + Bytes []byte + // the purpose of the file + Purpose PurposeType +} + // File struct represents an OpenAPI file. type File struct { - Bytes int `json:"bytes"` - CreatedAt int64 `json:"created_at"` - ID string `json:"id"` - FileName string `json:"filename"` - Object string `json:"object"` - Owner string `json:"owner"` - Purpose string `json:"purpose"` + Bytes int `json:"bytes"` + CreatedAt int64 `json:"created_at"` + ID string `json:"id"` + FileName string `json:"filename"` + Object string `json:"object"` + Status string `json:"status"` + Purpose string `json:"purpose"` + StatusDetails string `json:"status_details"` + + httpHeader } // FilesList is a list of files that belong to the user or organization. type FilesList struct { Files []File `json:"data"` + + httpHeader +} + +// CreateFileBytes uploads bytes directly to OpenAI without requiring a local file. +func (c *Client) CreateFileBytes(ctx context.Context, request FileBytesRequest) (file File, err error) { + var b bytes.Buffer + reader := bytes.NewReader(request.Bytes) + builder := c.createFormBuilder(&b) + + err = builder.WriteField("purpose", string(request.Purpose)) + if err != nil { + return + } + + err = builder.CreateFormFileReader("file", reader, request.Name) + if err != nil { + return + } + + err = builder.Close() + if err != nil { + return + } + + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/files"), + withBody(&b), withContentType(builder.FormDataContentType())) + if err != nil { + return + } + + err = c.sendRequest(req, &file) + return } // CreateFile uploads a jsonl file to GPT3 @@ -36,7 +93,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File var b bytes.Buffer builder := c.createFormBuilder(&b) - err = builder.writeField("purpose", request.Purpose) + err = builder.WriteField("purpose", request.Purpose) if err != nil { return } @@ -45,32 +102,31 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File if err != nil { return } + defer fileData.Close() - err = builder.createFormFile("file", fileData) + err = builder.CreateFormFile("file", fileData) if err != nil { return } - err = builder.close() + err = builder.Close() if err != nil { return } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/files"), &b) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/files"), + withBody(&b), withContentType(builder.FormDataContentType())) if err != nil { return } - req.Header.Set("Content-Type", builder.formDataContentType()) - err = c.sendRequest(req, &file) - return } // DeleteFile deletes an existing file. func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { - req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/files/"+fileID)) if err != nil { return } @@ -82,7 +138,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { // ListFiles Lists the currently available files, // and provides basic information about each file such as the file name and purpose. func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/files"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/files")) if err != nil { return } @@ -95,7 +151,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { // such as the file name and purpose. func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) { urlSuffix := fmt.Sprintf("/files/%s", fileID) - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } @@ -103,3 +159,13 @@ func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err err err = c.sendRequest(req, &file) return } + +func (c *Client) GetFileContent(ctx context.Context, fileID string) (content RawResponse, err error) { + urlSuffix := fmt.Sprintf("/files/%s/content", fileID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + return c.sendRequestRaw(req) +} diff --git a/files_api_test.go b/files_api_test.go new file mode 100644 index 000000000..aa4fda458 --- /dev/null +++ b/files_api_test.go @@ -0,0 +1,196 @@ +package openai_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "strconv" + "testing" + "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestFileBytesUpload(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", handleCreateFile) + req := openai.FileBytesRequest{ + Name: "foo", + Bytes: []byte("foo"), + Purpose: openai.PurposeFineTune, + } + _, err := client.CreateFileBytes(context.Background(), req) + checks.NoError(t, err, "CreateFile error") +} + +func TestFileUpload(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", handleCreateFile) + req := openai.FileRequest{ + FileName: "test.go", + FilePath: "client.go", + Purpose: "fine-tune", + } + _, err := client.CreateFile(context.Background(), req) + checks.NoError(t, err, "CreateFile error") +} + +// handleCreateFile Handles the images endpoint by the test server. +func handleCreateFile(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // edits only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + err = r.ParseMultipartForm(1024 * 1024 * 1024) + if err != nil { + http.Error(w, "file is more than 1GB", http.StatusInternalServerError) + return + } + + values := r.Form + var purpose string + for key, value := range values { + if key == "purpose" { + purpose = value[0] + } + } + file, header, err := r.FormFile("file") + if err != nil { + return + } + defer file.Close() + + fileReq := openai.File{ + Bytes: int(header.Size), + ID: strconv.Itoa(int(time.Now().Unix())), + FileName: header.Filename, + Purpose: purpose, + CreatedAt: time.Now().Unix(), + Object: "test-objecct", + } + + resBytes, _ = json.Marshal(fileReq) + fmt.Fprint(w, string(resBytes)) +} + +func TestDeleteFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files/deadbeef", func(http.ResponseWriter, *http.Request) {}) + err := client.DeleteFile(context.Background(), "deadbeef") + checks.NoError(t, err, "DeleteFile error") +} + +func TestListFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FilesList{}) + fmt.Fprintln(w, string(resBytes)) + }) + _, err := client.ListFiles(context.Background()) + checks.NoError(t, err, "ListFiles error") +} + +func TestGetFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.File{}) + fmt.Fprintln(w, string(resBytes)) + }) + _, err := client.GetFile(context.Background(), "deadbeef") + checks.NoError(t, err, "GetFile error") +} + +func TestGetFileContent(t *testing.T) { + wantRespJsonl := `{"prompt": "foo", "completion": "foo"} +{"prompt": "bar", "completion": "bar"} +{"prompt": "baz", "completion": "baz"} +` + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + // edits only accepts GET requests + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + fmt.Fprint(w, wantRespJsonl) + }) + + content, err := client.GetFileContent(context.Background(), "deadbeef") + checks.NoError(t, err, "GetFileContent error") + defer content.Close() + + actual, _ := io.ReadAll(content) + if string(actual) != wantRespJsonl { + t.Errorf("Expected %s, got %s", wantRespJsonl, string(actual)) + } +} + +func TestGetFileContentReturnError(t *testing.T) { + wantMessage := "To help mitigate abuse, downloading of fine-tune training files is disabled for free accounts." + wantType := "invalid_request_error" + wantErrorResp := `{ + "error": { + "message": "` + wantMessage + `", + "type": "` + wantType + `", + "param": null, + "code": null + } +}` + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, wantErrorResp) + }) + + _, err := client.GetFileContent(context.Background(), "deadbeef") + if err == nil { + t.Fatal("Did not return error") + } + + apiErr := &openai.APIError{} + if !errors.As(err, &apiErr) { + t.Fatalf("Did not return APIError: %+v\n", apiErr) + } + if apiErr.Message != wantMessage { + t.Fatalf("Expected %s Message, got = %s\n", wantMessage, apiErr.Message) + return + } + if apiErr.Type != wantType { + t.Fatalf("Expected %s Type, got = %s\n", wantType, apiErr.Type) + return + } +} + +func TestGetFileContentReturnTimeoutError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files/deadbeef/content", func(http.ResponseWriter, *http.Request) { + time.Sleep(10 * time.Nanosecond) + }) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) + defer cancel() + + _, err := client.GetFileContent(ctx, "deadbeef") + if err == nil { + t.Fatal("Did not return error") + } + if !os.IsTimeout(err) { + t.Fatal("Did not return timeout error") + } +} diff --git a/files_test.go b/files_test.go index bb06498c8..486ef892e 100644 --- a/files_test.go +++ b/files_test.go @@ -1,83 +1,61 @@ package openai //nolint:testpackage // testing private field import ( - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" - "encoding/json" "fmt" "io" - "net/http" "os" - "strconv" "testing" - "time" + + utils "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" ) -func TestFileUpload(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/files", handleCreateFile) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" +func TestFileBytesUploadWithFailingFormBuilder(t *testing.T) { + config := DefaultConfig("") + config.BaseURL = "" client := NewClientWithConfig(config) - ctx := context.Background() + mockBuilder := &mockFormBuilder{} + client.createFormBuilder = func(io.Writer) utils.FormBuilder { + return mockBuilder + } - req := FileRequest{ - FileName: "test.go", - FilePath: "client.go", - Purpose: "fine-tune", + ctx := context.Background() + req := FileBytesRequest{ + Name: "foo", + Bytes: []byte("foo"), + Purpose: PurposeAssistants, } - _, err = client.CreateFile(ctx, req) - checks.NoError(t, err, "CreateFile error") -} -// handleCreateFile Handles the images endpoint by the test server. -func handleCreateFile(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte + mockError := fmt.Errorf("mockWriteField error") + mockBuilder.mockWriteField = func(string, string) error { + return mockError + } + _, err := client.CreateFileBytes(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") - // edits only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + mockError = fmt.Errorf("mockCreateFormFile error") + mockBuilder.mockWriteField = func(string, string) error { + return nil } - err = r.ParseMultipartForm(1024 * 1024 * 1024) - if err != nil { - http.Error(w, "file is more than 1GB", http.StatusInternalServerError) - return + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { + return mockError } + _, err = client.CreateFileBytes(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") - values := r.Form - var purpose string - for key, value := range values { - if key == "purpose" { - purpose = value[0] - } + mockError = fmt.Errorf("mockClose error") + mockBuilder.mockWriteField = func(string, string) error { + return nil } - file, header, err := r.FormFile("file") - if err != nil { - return + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { + return nil } - defer file.Close() - - var fileReq = File{ - Bytes: int(header.Size), - ID: strconv.Itoa(int(time.Now().Unix())), - FileName: header.Filename, - Purpose: purpose, - CreatedAt: time.Now().Unix(), - Object: "test-objecct", - Owner: "test-owner", + mockBuilder.mockClose = func() error { + return mockError } - - resBytes, _ = json.Marshal(fileReq) - fmt.Fprint(w, string(resBytes)) + _, err = client.CreateFileBytes(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") } func TestFileUploadWithFailingFormBuilder(t *testing.T) { @@ -85,7 +63,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) { config.BaseURL = "" client := NewClientWithConfig(config) mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) formBuilder { + client.createFormBuilder = func(io.Writer) utils.FormBuilder { return mockBuilder } @@ -124,6 +102,9 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) { return mockError } _, err = client.CreateFile(ctx, req) + if err == nil { + t.Fatal("CreateFile should return error if form builder fails") + } checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") } @@ -140,3 +121,20 @@ func TestFileUploadWithNonExistentPath(t *testing.T) { _, err := client.CreateFile(ctx, req) checks.ErrorIs(t, err, os.ErrNotExist, "CreateFile should return error if file does not exist") } +func TestCreateFileRequestBuilderFailure(t *testing.T) { + config := DefaultConfig("") + config.BaseURL = "" + client := NewClientWithConfig(config) + client.requestBuilder = &failingRequestBuilder{} + + client.createFormBuilder = func(io.Writer) utils.FormBuilder { + return &mockFormBuilder{ + mockWriteField: func(string, string) error { return nil }, + mockCreateFormFile: func(string, *os.File) error { return nil }, + mockClose: func() error { return nil }, + } + } + + _, err := client.CreateFile(context.Background(), FileRequest{FilePath: "client.go"}) + checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateFile should return error if request builder fails") +} diff --git a/fine_tunes.go b/fine_tunes.go index a1218670f..74b47bf3f 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -6,6 +6,9 @@ import ( "net/http" ) +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneRequest struct { TrainingFile string `json:"training_file"` ValidationFile string `json:"validation_file,omitempty"` @@ -21,6 +24,9 @@ type FineTuneRequest struct { Suffix string `json:"suffix,omitempty"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTune struct { ID string `json:"id"` Object string `json:"object"` @@ -35,8 +41,13 @@ type FineTune struct { ValidationFiles []File `json:"validation_files"` TrainingFiles []File `json:"training_files"` UpdatedAt int64 `json:"updated_at"` + + httpHeader } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneEvent struct { Object string `json:"object"` CreatedAt int64 `json:"created_at"` @@ -44,6 +55,9 @@ type FineTuneEvent struct { Message string `json:"message"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneHyperParams struct { BatchSize int `json:"batch_size"` LearningRateMultiplier float64 `json:"learning_rate_multiplier"` @@ -51,24 +65,43 @@ type FineTuneHyperParams struct { PromptLossWeight float64 `json:"prompt_loss_weight"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneList struct { Object string `json:"object"` Data []FineTune `json:"data"` + + httpHeader } + +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneEventList struct { Object string `json:"object"` Data []FineTuneEvent `json:"data"` + + httpHeader } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneDeleteResponse struct { ID string `json:"id"` Object string `json:"object"` Deleted bool `json:"deleted"` + + httpHeader } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { urlSuffix := "/fine-tunes" - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) if err != nil { return } @@ -78,8 +111,11 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r } // CancelFineTune cancel a fine-tune job. +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) //nolint:lll //this method is deprecated if err != nil { return } @@ -88,8 +124,11 @@ func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (respons return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes")) if err != nil { return } @@ -98,9 +137,12 @@ func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID) - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } @@ -109,8 +151,11 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID)) if err != nil { return } @@ -119,8 +164,11 @@ func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (respons return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events")) if err != nil { return } diff --git a/fine_tunes_test.go b/fine_tunes_test.go index c60254993..2ab6817f7 100644 --- a/fine_tunes_test.go +++ b/fine_tunes_test.go @@ -1,30 +1,30 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" "net/http" "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) const testFineTuneID = "fine-tune-id" // TestFineTunes Tests the fine tunes endpoint of the API using the mocked server. func TestFineTunes(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler( "/v1/fine-tunes", func(w http.ResponseWriter, r *http.Request) { var resBytes []byte if r.Method == http.MethodGet { - resBytes, _ = json.Marshal(FineTuneList{}) + resBytes, _ = json.Marshal(openai.FineTuneList{}) } else { - resBytes, _ = json.Marshal(FineTune{}) + resBytes, _ = json.Marshal(openai.FineTune{}) } fmt.Fprintln(w, string(resBytes)) }, @@ -32,8 +32,8 @@ func TestFineTunes(t *testing.T) { server.RegisterHandler( "/v1/fine-tunes/"+testFineTuneID+"/cancel", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTune{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTune{}) fmt.Fprintln(w, string(resBytes)) }, ) @@ -43,9 +43,9 @@ func TestFineTunes(t *testing.T) { func(w http.ResponseWriter, r *http.Request) { var resBytes []byte if r.Method == http.MethodDelete { - resBytes, _ = json.Marshal(FineTuneDeleteResponse{}) + resBytes, _ = json.Marshal(openai.FineTuneDeleteResponse{}) } else { - resBytes, _ = json.Marshal(FineTune{}) + resBytes, _ = json.Marshal(openai.FineTune{}) } fmt.Fprintln(w, string(resBytes)) }, @@ -53,27 +53,18 @@ func TestFineTunes(t *testing.T) { server.RegisterHandler( "/v1/fine-tunes/"+testFineTuneID+"/events", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuneEventList{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuneEventList{}) fmt.Fprintln(w, string(resBytes)) }, ) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) ctx := context.Background() - _, err = client.ListFineTunes(ctx) + _, err := client.ListFineTunes(ctx) checks.NoError(t, err, "ListFineTunes error") - _, err = client.CreateFineTune(ctx, FineTuneRequest{}) + _, err = client.CreateFineTune(ctx, openai.FineTuneRequest{}) checks.NoError(t, err, "CreateFineTune error") _, err = client.CancelFineTune(ctx, testFineTuneID) diff --git a/fine_tuning_job.go b/fine_tuning_job.go new file mode 100644 index 000000000..5a9f54a92 --- /dev/null +++ b/fine_tuning_job.go @@ -0,0 +1,159 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +type FineTuningJob struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + FinishedAt int64 `json:"finished_at"` + Model string `json:"model"` + FineTunedModel string `json:"fine_tuned_model,omitempty"` + OrganizationID string `json:"organization_id"` + Status string `json:"status"` + Hyperparameters Hyperparameters `json:"hyperparameters"` + TrainingFile string `json:"training_file"` + ValidationFile string `json:"validation_file,omitempty"` + ResultFiles []string `json:"result_files"` + TrainedTokens int `json:"trained_tokens"` + + httpHeader +} + +type Hyperparameters struct { + Epochs any `json:"n_epochs,omitempty"` + LearningRateMultiplier any `json:"learning_rate_multiplier,omitempty"` + BatchSize any `json:"batch_size,omitempty"` +} + +type FineTuningJobRequest struct { + TrainingFile string `json:"training_file"` + ValidationFile string `json:"validation_file,omitempty"` + Model string `json:"model,omitempty"` + Hyperparameters *Hyperparameters `json:"hyperparameters,omitempty"` + Suffix string `json:"suffix,omitempty"` +} + +type FineTuningJobEventList struct { + Object string `json:"object"` + Data []FineTuneEvent `json:"data"` + HasMore bool `json:"has_more"` + + httpHeader +} + +type FineTuningJobEvent struct { + Object string `json:"object"` + ID string `json:"id"` + CreatedAt int `json:"created_at"` + Level string `json:"level"` + Message string `json:"message"` + Data any `json:"data"` + Type string `json:"type"` +} + +// CreateFineTuningJob create a fine tuning job. +func (c *Client) CreateFineTuningJob( + ctx context.Context, + request FineTuningJobRequest, +) (response FineTuningJob, err error) { + urlSuffix := "/fine_tuning/jobs" + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// CancelFineTuningJob cancel a fine tuning job. +func (c *Client) CancelFineTuningJob(ctx context.Context, fineTuningJobID string) (response FineTuningJob, err error) { + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/cancel")) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveFineTuningJob retrieve a fine tuning job. +func (c *Client) RetrieveFineTuningJob( + ctx context.Context, + fineTuningJobID string, +) (response FineTuningJob, err error) { + urlSuffix := fmt.Sprintf("/fine_tuning/jobs/%s", fineTuningJobID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +type listFineTuningJobEventsParameters struct { + after *string + limit *int +} + +type ListFineTuningJobEventsParameter func(*listFineTuningJobEventsParameters) + +func ListFineTuningJobEventsWithAfter(after string) ListFineTuningJobEventsParameter { + return func(args *listFineTuningJobEventsParameters) { + args.after = &after + } +} + +func ListFineTuningJobEventsWithLimit(limit int) ListFineTuningJobEventsParameter { + return func(args *listFineTuningJobEventsParameters) { + args.limit = &limit + } +} + +// ListFineTuningJobs list fine tuning jobs events. +func (c *Client) ListFineTuningJobEvents( + ctx context.Context, + fineTuningJobID string, + setters ...ListFineTuningJobEventsParameter, +) (response FineTuningJobEventList, err error) { + parameters := &listFineTuningJobEventsParameters{ + after: nil, + limit: nil, + } + + for _, setter := range setters { + setter(parameters) + } + + urlValues := url.Values{} + if parameters.after != nil { + urlValues.Add("after", *parameters.after) + } + if parameters.limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *parameters.limit)) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/events"+encodedValues), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go new file mode 100644 index 000000000..5f63ef24c --- /dev/null +++ b/fine_tuning_job_test.go @@ -0,0 +1,106 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +const testFineTuninigJobID = "fine-tuning-job-id" + +// TestFineTuningJob Tests the fine tuning job endpoint of the API using the mocked server. +func TestFineTuningJob(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler( + "/v1/fine_tuning/jobs", + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuningJob{ + Object: "fine_tuning.job", + ID: testFineTuninigJobID, + Model: "davinci-002", + CreatedAt: 1692661014, + FinishedAt: 1692661190, + FineTunedModel: "ft:davinci-002:my-org:custom_suffix:7q8mpxmy", + OrganizationID: "org-123", + ResultFiles: []string{"file-abc123"}, + Status: "succeeded", + ValidationFile: "", + TrainingFile: "file-abc123", + Hyperparameters: openai.Hyperparameters{ + Epochs: "auto", + LearningRateMultiplier: "auto", + BatchSize: "auto", + }, + TrainedTokens: 5768, + }) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel", + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuningJob{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/v1/fine_tuning/jobs/"+testFineTuninigJobID, + func(w http.ResponseWriter, _ *http.Request) { + var resBytes []byte + resBytes, _ = json.Marshal(openai.FineTuningJob{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events", + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuningJobEventList{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + ctx := context.Background() + + _, err := client.CreateFineTuningJob(ctx, openai.FineTuningJobRequest{}) + checks.NoError(t, err, "CreateFineTuningJob error") + + _, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID) + checks.NoError(t, err, "CancelFineTuningJob error") + + _, err = client.RetrieveFineTuningJob(ctx, testFineTuninigJobID) + checks.NoError(t, err, "RetrieveFineTuningJob error") + + _, err = client.ListFineTuningJobEvents(ctx, testFineTuninigJobID) + checks.NoError(t, err, "ListFineTuningJobEvents error") + + _, err = client.ListFineTuningJobEvents( + ctx, + testFineTuninigJobID, + openai.ListFineTuningJobEventsWithAfter("last-event-id"), + ) + checks.NoError(t, err, "ListFineTuningJobEvents error") + + _, err = client.ListFineTuningJobEvents( + ctx, + testFineTuninigJobID, + openai.ListFineTuningJobEventsWithLimit(10), + ) + checks.NoError(t, err, "ListFineTuningJobEvents error") + + _, err = client.ListFineTuningJobEvents( + ctx, + testFineTuninigJobID, + openai.ListFineTuningJobEventsWithAfter("last-event-id"), + openai.ListFineTuningJobEventsWithLimit(10), + ) + checks.NoError(t, err, "ListFineTuningJobEvents error") +} diff --git a/form_builder.go b/form_builder.go deleted file mode 100644 index 7fbb1643a..000000000 --- a/form_builder.go +++ /dev/null @@ -1,49 +0,0 @@ -package openai - -import ( - "io" - "mime/multipart" - "os" -) - -type formBuilder interface { - createFormFile(fieldname string, file *os.File) error - writeField(fieldname, value string) error - close() error - formDataContentType() string -} - -type defaultFormBuilder struct { - writer *multipart.Writer -} - -func newFormBuilder(body io.Writer) *defaultFormBuilder { - return &defaultFormBuilder{ - writer: multipart.NewWriter(body), - } -} - -func (fb *defaultFormBuilder) createFormFile(fieldname string, file *os.File) error { - fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name()) - if err != nil { - return err - } - - _, err = io.Copy(fieldWriter, file) - if err != nil { - return err - } - return nil -} - -func (fb *defaultFormBuilder) writeField(fieldname, value string) error { - return fb.writer.WriteField(fieldname, value) -} - -func (fb *defaultFormBuilder) close() error { - return fb.writer.Close() -} - -func (fb *defaultFormBuilder) formDataContentType() string { - return fb.writer.FormDataContentType() -} diff --git a/form_builder_test.go b/form_builder_test.go deleted file mode 100644 index 78e2ec968..000000000 --- a/form_builder_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package openai //nolint:testpackage // testing private field - -import ( - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" - - "bytes" - "errors" - "os" - "testing" -) - -type failingWriter struct { -} - -var errMockFailingWriterError = errors.New("mock writer failed") - -func (*failingWriter) Write([]byte) (int, error) { - return 0, errMockFailingWriterError -} - -func TestFormBuilderWithFailingWriter(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - - file, err := os.CreateTemp(dir, "") - if err != nil { - t.Errorf("Error creating tmp file: %v", err) - } - defer file.Close() - defer os.Remove(file.Name()) - - builder := newFormBuilder(&failingWriter{}) - err = builder.createFormFile("file", file) - checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails") -} - -func TestFormBuilderWithClosedFile(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - - file, err := os.CreateTemp(dir, "") - if err != nil { - t.Errorf("Error creating tmp file: %v", err) - } - file.Close() - defer os.Remove(file.Name()) - - body := &bytes.Buffer{} - builder := newFormBuilder(body) - err = builder.createFormFile("file", file) - checks.HasError(t, err, "formbuilder should return error if file is closed") - checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed") -} diff --git a/image.go b/image.go index 21703bda7..84b9daf02 100644 --- a/image.go +++ b/image.go @@ -3,8 +3,8 @@ package openai import ( "bytes" "context" + "io" "net/http" - "os" "strconv" ) @@ -13,38 +13,117 @@ const ( CreateImageSize256x256 = "256x256" CreateImageSize512x512 = "512x512" CreateImageSize1024x1024 = "1024x1024" + + // dall-e-3 supported only. + CreateImageSize1792x1024 = "1792x1024" + CreateImageSize1024x1792 = "1024x1792" + + // gpt-image-1 supported only. + CreateImageSize1536x1024 = "1536x1024" // Landscape + CreateImageSize1024x1536 = "1024x1536" // Portrait ) const ( - CreateImageResponseFormatURL = "url" + // dall-e-2 and dall-e-3 only. CreateImageResponseFormatB64JSON = "b64_json" + CreateImageResponseFormatURL = "url" +) + +const ( + CreateImageModelDallE2 = "dall-e-2" + CreateImageModelDallE3 = "dall-e-3" + CreateImageModelGptImage1 = "gpt-image-1" +) + +const ( + CreateImageQualityHD = "hd" + CreateImageQualityStandard = "standard" + + // gpt-image-1 only. + CreateImageQualityHigh = "high" + CreateImageQualityMedium = "medium" + CreateImageQualityLow = "low" +) + +const ( + // dall-e-3 only. + CreateImageStyleVivid = "vivid" + CreateImageStyleNatural = "natural" +) + +const ( + // gpt-image-1 only. + CreateImageBackgroundTransparent = "transparent" + CreateImageBackgroundOpaque = "opaque" +) + +const ( + // gpt-image-1 only. + CreateImageModerationLow = "low" +) + +const ( + // gpt-image-1 only. + CreateImageOutputFormatPNG = "png" + CreateImageOutputFormatJPEG = "jpeg" + CreateImageOutputFormatWEBP = "webp" ) // ImageRequest represents the request structure for the image API. type ImageRequest struct { - Prompt string `json:"prompt,omitempty"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - User string `json:"user,omitempty"` + Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Quality string `json:"quality,omitempty"` + Size string `json:"size,omitempty"` + Style string `json:"style,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + User string `json:"user,omitempty"` + Background string `json:"background,omitempty"` + Moderation string `json:"moderation,omitempty"` + OutputCompression int `json:"output_compression,omitempty"` + OutputFormat string `json:"output_format,omitempty"` } // ImageResponse represents a response structure for image API. type ImageResponse struct { Created int64 `json:"created,omitempty"` Data []ImageResponseDataInner `json:"data,omitempty"` + Usage ImageResponseUsage `json:"usage,omitempty"` + + httpHeader +} + +// ImageResponseInputTokensDetails represents the token breakdown for input tokens. +type ImageResponseInputTokensDetails struct { + TextTokens int `json:"text_tokens,omitempty"` + ImageTokens int `json:"image_tokens,omitempty"` +} + +// ImageResponseUsage represents the token usage information for image API. +type ImageResponseUsage struct { + TotalTokens int `json:"total_tokens,omitempty"` + InputTokens int `json:"input_tokens,omitempty"` + OutputTokens int `json:"output_tokens,omitempty"` + InputTokensDetails ImageResponseInputTokensDetails `json:"input_tokens_details,omitempty"` } // ImageResponseDataInner represents a response data structure for image API. type ImageResponseDataInner struct { - URL string `json:"url,omitempty"` - B64JSON string `json:"b64_json,omitempty"` + URL string `json:"url,omitempty"` + B64JSON string `json:"b64_json,omitempty"` + RevisedPrompt string `json:"revised_prompt,omitempty"` } // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { urlSuffix := "/images/generations" - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } @@ -53,14 +132,42 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons return } +// WrapReader wraps an io.Reader with filename and Content-type. +func WrapReader(rdr io.Reader, filename string, contentType string) io.Reader { + return file{rdr, filename, contentType} +} + +type file struct { + io.Reader + name string + contentType string +} + +func (f file) Name() string { + if f.name != "" { + return f.name + } else if named, ok := f.Reader.(interface{ Name() string }); ok { + return named.Name() + } + return "" +} + +func (f file) ContentType() string { + return f.contentType +} + // ImageEditRequest represents the request structure for the image API. +// Use WrapReader to wrap an io.Reader with filename and Content-type. type ImageEditRequest struct { - Image *os.File `json:"image,omitempty"` - Mask *os.File `json:"mask,omitempty"` - Prompt string `json:"prompt,omitempty"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` + Image io.Reader `json:"image,omitempty"` + Mask io.Reader `json:"mask,omitempty"` + Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Quality string `json:"quality,omitempty"` + User string `json:"user,omitempty"` } // CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API. @@ -68,62 +175,70 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image - err = builder.createFormFile("image", request.Image) + // image, filename verification can be postponed + err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return } // mask, it is optional if request.Mask != nil { - err = builder.createFormFile("mask", request.Mask) + // filename verification can be postponed + err = builder.CreateFormFileReader("mask", request.Mask, "") if err != nil { return } } - err = builder.writeField("prompt", request.Prompt) + err = builder.WriteField("prompt", request.Prompt) if err != nil { return } - err = builder.writeField("n", strconv.Itoa(request.N)) + err = builder.WriteField("n", strconv.Itoa(request.N)) if err != nil { return } - err = builder.writeField("size", request.Size) + err = builder.WriteField("size", request.Size) if err != nil { return } - err = builder.writeField("response_format", request.ResponseFormat) + err = builder.WriteField("response_format", request.ResponseFormat) if err != nil { return } - err = builder.close() + err = builder.Close() if err != nil { return } - urlSuffix := "/images/edits" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/images/edits", withModel(request.Model)), + withBody(body), + withContentType(builder.FormDataContentType()), + ) if err != nil { return } - req.Header.Set("Content-Type", builder.formDataContentType()) err = c.sendRequest(req, &response) return } // ImageVariRequest represents the request structure for the image API. +// Use WrapReader to wrap an io.Reader with filename and Content-type. type ImageVariRequest struct { - Image *os.File `json:"image,omitempty"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` + Image io.Reader `json:"image,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + User string `json:"user,omitempty"` } // CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API. @@ -132,40 +247,43 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image - err = builder.createFormFile("image", request.Image) + // image, filename verification can be postponed + err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return } - err = builder.writeField("n", strconv.Itoa(request.N)) + err = builder.WriteField("n", strconv.Itoa(request.N)) if err != nil { return } - err = builder.writeField("size", request.Size) + err = builder.WriteField("size", request.Size) if err != nil { return } - err = builder.writeField("response_format", request.ResponseFormat) + err = builder.WriteField("response_format", request.ResponseFormat) if err != nil { return } - err = builder.close() + err = builder.Close() if err != nil { return } - //https://platform.openai.com/docs/api-reference/images/create-variation - urlSuffix := "/images/variations" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/images/variations", withModel(request.Model)), + withBody(body), + withContentType(builder.FormDataContentType()), + ) if err != nil { return } - req.Header.Set("Content-Type", builder.formDataContentType()) err = c.sendRequest(req, &response) return } diff --git a/image_api_test.go b/image_api_test.go new file mode 100644 index 000000000..f6057b77d --- /dev/null +++ b/image_api_test.go @@ -0,0 +1,214 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestImages(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/images/generations", handleImageEndpoint) + _, err := client.CreateImage(context.Background(), openai.ImageRequest{ + Prompt: "Lorem ipsum", + Model: openai.CreateImageModelDallE3, + N: 1, + Quality: openai.CreateImageQualityHD, + Size: openai.CreateImageSize1024x1024, + Style: openai.CreateImageStyleVivid, + ResponseFormat: openai.CreateImageResponseFormatURL, + User: "user", + }) + checks.NoError(t, err, "CreateImage error") +} + +// handleImageEndpoint Handles the images endpoint by the test server. +func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // images only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var imageReq openai.ImageRequest + if imageReq, err = getImageBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + res := openai.ImageResponse{ + Created: time.Now().Unix(), + } + for i := 0; i < imageReq.N; i++ { + imageData := openai.ImageResponseDataInner{} + switch imageReq.ResponseFormat { + case openai.CreateImageResponseFormatURL, "": + imageData.URL = "https://example.com/image.png" + case openai.CreateImageResponseFormatB64JSON: + // This decodes to "{}" in base64. + imageData.B64JSON = "e30K" + default: + http.Error(w, "invalid response format", http.StatusBadRequest) + return + } + res.Data = append(res.Data, imageData) + } + resBytes, _ = json.Marshal(res) + fmt.Fprintln(w, string(resBytes)) +} + +// getImageBody Returns the body of the request to create a image. +func getImageBody(r *http.Request) (openai.ImageRequest, error) { + image := openai.ImageRequest{} + // read the request body + reqBody, err := io.ReadAll(r.Body) + if err != nil { + return openai.ImageRequest{}, err + } + err = json.Unmarshal(reqBody, &image) + if err != nil { + return openai.ImageRequest{}, err + } + return image, nil +} + +func TestImageEdit(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) + + origin, err := os.Create(filepath.Join(t.TempDir(), "image.png")) + if err != nil { + t.Fatalf("open origin file error: %v", err) + } + defer origin.Close() + + mask, err := os.Create(filepath.Join(t.TempDir(), "mask.png")) + if err != nil { + t.Fatalf("open mask file error: %v", err) + } + defer mask.Close() + + _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ + Image: origin, + Mask: mask, + Prompt: "There is a turtle in the pool", + N: 3, + Size: openai.CreateImageSize1024x1024, + ResponseFormat: openai.CreateImageResponseFormatURL, + }) + checks.NoError(t, err, "CreateImage error") +} + +func TestImageEditWithoutMask(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) + + origin, err := os.Create(filepath.Join(t.TempDir(), "image.png")) + if err != nil { + t.Fatalf("open origin file error: %v", err) + } + defer origin.Close() + + _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ + Image: origin, + Prompt: "There is a turtle in the pool", + N: 3, + Size: openai.CreateImageSize1024x1024, + ResponseFormat: openai.CreateImageResponseFormatURL, + }) + checks.NoError(t, err, "CreateImage error") +} + +// handleEditImageEndpoint Handles the images endpoint by the test server. +func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + + // images only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + responses := openai.ImageResponse{ + Created: time.Now().Unix(), + Data: []openai.ImageResponseDataInner{ + { + URL: "test-url1", + B64JSON: "", + }, + { + URL: "test-url2", + B64JSON: "", + }, + { + URL: "test-url3", + B64JSON: "", + }, + }, + } + + resBytes, _ = json.Marshal(responses) + fmt.Fprintln(w, string(resBytes)) +} + +func TestImageVariation(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint) + + origin, err := os.Create(filepath.Join(t.TempDir(), "image.png")) + if err != nil { + t.Fatalf("open origin file error: %v", err) + } + defer origin.Close() + + _, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{ + Image: origin, + N: 3, + Size: openai.CreateImageSize1024x1024, + ResponseFormat: openai.CreateImageResponseFormatURL, + }) + checks.NoError(t, err, "CreateImage error") +} + +// handleVariateImageEndpoint Handles the images endpoint by the test server. +func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + + // images only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + responses := openai.ImageResponse{ + Created: time.Now().Unix(), + Data: []openai.ImageResponseDataInner{ + { + URL: "test-url1", + B64JSON: "", + }, + { + URL: "test-url2", + B64JSON: "", + }, + { + URL: "test-url3", + B64JSON: "", + }, + }, + } + + resBytes, _ = json.Marshal(responses) + fmt.Fprintln(w, string(resBytes)) +} diff --git a/image_test.go b/image_test.go index 4a7dad58f..c2c8f42dc 100644 --- a/image_test.go +++ b/image_test.go @@ -1,404 +1,323 @@ package openai //nolint:testpackage // testing private field import ( - "github.com/sashabaranov/go-openai/internal/test" + utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test/checks" + "bytes" "context" - "encoding/json" "fmt" "io" - "net/http" "os" "testing" - "time" ) -func TestImages(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/images/generations", handleImageEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - req := ImageRequest{} - req.Prompt = "Lorem ipsum" - _, err = client.CreateImage(ctx, req) - checks.NoError(t, err, "CreateImage error") +type mockFormBuilder struct { + mockCreateFormFile func(string, *os.File) error + mockCreateFormFileReader func(string, io.Reader, string) error + mockWriteField func(string, string) error + mockClose func() error } -// handleImageEndpoint Handles the images endpoint by the test server. -func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // imagess only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - var imageReq ImageRequest - if imageReq, err = getImageBody(r); err != nil { - http.Error(w, "could not read request", http.StatusInternalServerError) - return - } - res := ImageResponse{ - Created: time.Now().Unix(), - } - for i := 0; i < imageReq.N; i++ { - imageData := ImageResponseDataInner{} - switch imageReq.ResponseFormat { - case CreateImageResponseFormatURL, "": - imageData.URL = "https://example.com/image.png" - case CreateImageResponseFormatB64JSON: - // This decodes to "{}" in base64. - imageData.B64JSON = "e30K" - default: - http.Error(w, "invalid response format", http.StatusBadRequest) - return - } - res.Data = append(res.Data, imageData) - } - resBytes, _ = json.Marshal(res) - fmt.Fprintln(w, string(resBytes)) +func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error { + return fb.mockCreateFormFile(fieldname, file) } -// getImageBody Returns the body of the request to create a image. -func getImageBody(r *http.Request) (ImageRequest, error) { - image := ImageRequest{} - // read the request body - reqBody, err := io.ReadAll(r.Body) - if err != nil { - return ImageRequest{}, err - } - err = json.Unmarshal(reqBody, &image) - if err != nil { - return ImageRequest{}, err - } - return image, nil +func (fb *mockFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { + return fb.mockCreateFormFileReader(fieldname, r, filename) } -func TestImageEdit(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - origin, err := os.Create("image.png") - if err != nil { - t.Error("open origin file error") - return - } - - mask, err := os.Create("mask.png") - if err != nil { - t.Error("open mask file error") - return - } - - defer func() { - mask.Close() - origin.Close() - os.Remove("mask.png") - os.Remove("image.png") - }() - - req := ImageEditRequest{ - Image: origin, - Mask: mask, - Prompt: "There is a turtle in the pool", - N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, - } - _, err = client.CreateEditImage(ctx, req) - checks.NoError(t, err, "CreateImage error") +func (fb *mockFormBuilder) WriteField(fieldname, value string) error { + return fb.mockWriteField(fieldname, value) } -func TestImageEditWithoutMask(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - origin, err := os.Create("image.png") - if err != nil { - t.Error("open origin file error") - return - } - - defer func() { - origin.Close() - os.Remove("image.png") - }() +func (fb *mockFormBuilder) Close() error { + return fb.mockClose() +} - req := ImageEditRequest{ - Image: origin, - Prompt: "There is a turtle in the pool", - N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, - } - _, err = client.CreateEditImage(ctx, req) - checks.NoError(t, err, "CreateImage error") +func (fb *mockFormBuilder) FormDataContentType() string { + return "" } -// handleEditImageEndpoint Handles the images endpoint by the test server. -func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { - var resBytes []byte +func TestImageFormBuilderFailures(t *testing.T) { + ctx := context.Background() + mockFailedErr := fmt.Errorf("mock form builder fail") - // imagess only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + newClient := func(fb *mockFormBuilder) *Client { + cfg := DefaultConfig("") + cfg.BaseURL = "" + c := NewClientWithConfig(cfg) + c.createFormBuilder = func(io.Writer) utils.FormBuilder { return fb } + return c } - responses := ImageResponse{ - Created: time.Now().Unix(), - Data: []ImageResponseDataInner{ - { - URL: "test-url1", - B64JSON: "", + tests := []struct { + name string + setup func(*mockFormBuilder) + req ImageEditRequest + }{ + { + name: "image", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "mask", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error { + if name == "mask" { + return mockFailedErr + } + return nil + } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return nil } }, - { - URL: "test-url2", - B64JSON: "", + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "prompt", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "prompt" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } }, - { - URL: "test-url3", - B64JSON: "", + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "n", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "n" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "size", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "size" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "response_format", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "response_format" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "close", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return mockFailedErr } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, }, } - resBytes, _ = json.Marshal(responses) - fmt.Fprintln(w, string(resBytes)) -} - -func TestImageVariation(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - origin, err := os.Create("image.png") - if err != nil { - t.Error("open origin file error") - return + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + fb := &mockFormBuilder{} + tc.setup(fb) + client := newClient(fb) + _, err := client.CreateEditImage(ctx, tc.req) + checks.ErrorIs(t, err, mockFailedErr, "CreateEditImage should return error if form builder fails") + }) } - defer func() { - origin.Close() - os.Remove("image.png") - }() + t.Run("new request", func(t *testing.T) { + fb := &mockFormBuilder{ + mockCreateFormFileReader: func(string, io.Reader, string) error { return nil }, + mockWriteField: func(string, string) error { return nil }, + mockClose: func() error { return nil }, + } + client := newClient(fb) + client.requestBuilder = &failingRequestBuilder{} - req := ImageVariRequest{ - Image: origin, - N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, - } - _, err = client.CreateVariImage(ctx, req) - checks.NoError(t, err, "CreateImage error") + _, err := client.CreateEditImage(ctx, ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}) + checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateEditImage should return error if request builder fails") + }) } -// handleVariateImageEndpoint Handles the images endpoint by the test server. -func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { - var resBytes []byte +func TestVariImageFormBuilderFailures(t *testing.T) { + ctx := context.Background() + mockFailedErr := fmt.Errorf("mock form builder fail") - // imagess only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + newClient := func(fb *mockFormBuilder) *Client { + cfg := DefaultConfig("") + cfg.BaseURL = "" + c := NewClientWithConfig(cfg) + c.createFormBuilder = func(io.Writer) utils.FormBuilder { return fb } + return c } - responses := ImageResponse{ - Created: time.Now().Unix(), - Data: []ImageResponseDataInner{ - { - URL: "test-url1", - B64JSON: "", + tests := []struct { + name string + setup func(*mockFormBuilder) + req ImageVariRequest + }{ + { + name: "image", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "n", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field string, _ string) error { + if field == "n" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } }, - { - URL: "test-url2", - B64JSON: "", + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "size", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field string, _ string) error { + if field == "size" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } }, - { - URL: "test-url3", - B64JSON: "", + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "response_format", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field string, _ string) error { + if field == "response_format" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "close", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return mockFailedErr } }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, }, } - resBytes, _ = json.Marshal(responses) - fmt.Fprintln(w, string(resBytes)) -} - -type mockFormBuilder struct { - mockCreateFormFile func(string, *os.File) error - mockWriteField func(string, string) error - mockClose func() error -} - -func (fb *mockFormBuilder) createFormFile(fieldname string, file *os.File) error { - return fb.mockCreateFormFile(fieldname, file) -} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + fb := &mockFormBuilder{} + tc.setup(fb) + client := newClient(fb) + _, err := client.CreateVariImage(ctx, tc.req) + checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") + }) + } -func (fb *mockFormBuilder) writeField(fieldname, value string) error { - return fb.mockWriteField(fieldname, value) -} + t.Run("new request", func(t *testing.T) { + fb := &mockFormBuilder{ + mockCreateFormFileReader: func(string, io.Reader, string) error { return nil }, + mockWriteField: func(string, string) error { return nil }, + mockClose: func() error { return nil }, + } + client := newClient(fb) + client.requestBuilder = &failingRequestBuilder{} -func (fb *mockFormBuilder) close() error { - return fb.mockClose() + _, err := client.CreateVariImage(ctx, ImageVariRequest{Image: bytes.NewBuffer(nil)}) + checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateVariImage should return error if request builder fails") + }) } -func (fb *mockFormBuilder) formDataContentType() string { - return "" -} +type testNamedReader struct{ io.Reader } -func TestImageFormBuilderFailures(t *testing.T) { - config := DefaultConfig("") - config.BaseURL = "" - client := NewClientWithConfig(config) +func (testNamedReader) Name() string { return "named.txt" } - mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) formBuilder { - return mockBuilder +func TestWrapReader(t *testing.T) { + r := bytes.NewBufferString("data") + wrapped := WrapReader(r, "file.png", "image/png") + f, ok := wrapped.(interface { + Name() string + ContentType() string + }) + if !ok { + t.Fatal("wrapped reader missing Name or ContentType") } - ctx := context.Background() - - req := ImageEditRequest{ - Mask: &os.File{}, + if f.Name() != "file.png" { + t.Fatalf("expected name file.png, got %s", f.Name()) } - - mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { - return mockFailedErr + if f.ContentType() != "image/png" { + t.Fatalf("expected content type image/png, got %s", f.ContentType()) } - _, err := client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(name string, file *os.File) error { - if name == "mask" { - return mockFailedErr - } - return nil + // test name from underlying reader + nr := testNamedReader{Reader: bytes.NewBufferString("d")} + wrapped = WrapReader(nr, "", "text/plain") + f, ok = wrapped.(interface { + Name() string + ContentType() string + }) + if !ok { + t.Fatal("wrapped named reader missing Name or ContentType") } - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - - mockBuilder.mockCreateFormFile = func(name string, file *os.File) error { - return nil - } - - var failForField string - mockBuilder.mockWriteField = func(fieldname, value string) error { - if fieldname == failForField { - return mockFailedErr - } - return nil - } - - failForField = "prompt" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - - failForField = "n" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - - failForField = "size" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - - failForField = "response_format" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - - failForField = "" - mockBuilder.mockClose = func() error { - return mockFailedErr - } - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") -} - -func TestVariImageFormBuilderFailures(t *testing.T) { - config := DefaultConfig("") - config.BaseURL = "" - client := NewClientWithConfig(config) - - mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) formBuilder { - return mockBuilder + if f.Name() != "named.txt" { + t.Fatalf("expected name named.txt, got %s", f.Name()) } - ctx := context.Background() - - req := ImageVariRequest{} - - mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { - return mockFailedErr + if f.ContentType() != "text/plain" { + t.Fatalf("expected content type text/plain, got %s", f.ContentType()) } - _, err := client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(name string, file *os.File) error { - return nil + // no name provided + wrapped = WrapReader(bytes.NewBuffer(nil), "", "") + f2, ok := wrapped.(interface{ Name() string }) + if !ok { + t.Fatal("wrapped anonymous reader missing Name") } - - var failForField string - mockBuilder.mockWriteField = func(fieldname, value string) error { - if fieldname == failForField { - return mockFailedErr - } - return nil - } - - failForField = "n" - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - - failForField = "size" - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - - failForField = "response_format" - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - - failForField = "" - mockBuilder.mockClose = func() error { - return mockFailedErr + if f2.Name() != "" { + t.Fatalf("expected empty name, got %s", f2.Name()) } - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") } diff --git a/internal/error_accumulator.go b/internal/error_accumulator.go new file mode 100644 index 000000000..3d3e805fe --- /dev/null +++ b/internal/error_accumulator.go @@ -0,0 +1,44 @@ +package openai + +import ( + "bytes" + "fmt" + "io" +) + +type ErrorAccumulator interface { + Write(p []byte) error + Bytes() []byte +} + +type errorBuffer interface { + io.Writer + Len() int + Bytes() []byte +} + +type DefaultErrorAccumulator struct { + Buffer errorBuffer +} + +func NewErrorAccumulator() ErrorAccumulator { + return &DefaultErrorAccumulator{ + Buffer: &bytes.Buffer{}, + } +} + +func (e *DefaultErrorAccumulator) Write(p []byte) error { + _, err := e.Buffer.Write(p) + if err != nil { + return fmt.Errorf("error accumulator write error, %w", err) + } + return nil +} + +func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) { + if e.Buffer.Len() == 0 { + return + } + errBytes = e.Buffer.Bytes() + return +} diff --git a/internal/error_accumulator_test.go b/internal/error_accumulator_test.go new file mode 100644 index 000000000..f6c226c5e --- /dev/null +++ b/internal/error_accumulator_test.go @@ -0,0 +1,39 @@ +package openai_test + +import ( + "testing" + + openai "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestDefaultErrorAccumulator_WriteMultiple(t *testing.T) { + ea, ok := openai.NewErrorAccumulator().(*openai.DefaultErrorAccumulator) + if !ok { + t.Fatal("type assertion to *DefaultErrorAccumulator failed") + } + checks.NoError(t, ea.Write([]byte("{\"error\": \"test1\"}"))) + checks.NoError(t, ea.Write([]byte("{\"error\": \"test2\"}"))) + + expected := "{\"error\": \"test1\"}{\"error\": \"test2\"}" + if string(ea.Bytes()) != expected { + t.Fatalf("Expected %q, got %q", expected, ea.Bytes()) + } +} + +func TestDefaultErrorAccumulator_EmptyBuffer(t *testing.T) { + ea, ok := openai.NewErrorAccumulator().(*openai.DefaultErrorAccumulator) + if !ok { + t.Fatal("type assertion to *DefaultErrorAccumulator failed") + } + if len(ea.Bytes()) != 0 { + t.Fatal("Buffer should be empty initially") + } +} + +func TestDefaultErrorAccumulator_WriteError(t *testing.T) { + ea := &openai.DefaultErrorAccumulator{Buffer: &test.FailingErrorBuffer{}} + err := ea.Write([]byte("fail")) + checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Write should propagate buffer errors") +} diff --git a/internal/form_builder.go b/internal/form_builder.go new file mode 100644 index 000000000..a17e820c0 --- /dev/null +++ b/internal/form_builder.go @@ -0,0 +1,112 @@ +package openai + +import ( + "fmt" + "io" + "mime/multipart" + "net/textproto" + "os" + "path/filepath" + "strings" +) + +type FormBuilder interface { + CreateFormFile(fieldname string, file *os.File) error + CreateFormFileReader(fieldname string, r io.Reader, filename string) error + WriteField(fieldname, value string) error + Close() error + FormDataContentType() string +} + +type DefaultFormBuilder struct { + writer *multipart.Writer +} + +func NewFormBuilder(body io.Writer) *DefaultFormBuilder { + return &DefaultFormBuilder{ + writer: multipart.NewWriter(body), + } +} + +func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error { + return fb.createFormFile(fieldname, file, file.Name()) +} + +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + +func escapeQuotes(s string) string { + return quoteEscaper.Replace(s) +} + +// CreateFormFileReader creates a form field with a file reader. +// The filename in Content-Disposition is required. +func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { + if filename == "" { + if f, ok := r.(interface{ Name() string }); ok { + filename = f.Name() + } + } + var contentType string + if f, ok := r.(interface{ ContentType() string }); ok { + contentType = f.ContentType() + } + + h := make(textproto.MIMEHeader) + h.Set( + "Content-Disposition", + fmt.Sprintf( + `form-data; name="%s"; filename="%s"`, + escapeQuotes(fieldname), + escapeQuotes(filepath.Base(filename)), + ), + ) + // content type is optional, but it can be set + if contentType != "" { + h.Set("Content-Type", contentType) + } + + fieldWriter, err := fb.writer.CreatePart(h) + if err != nil { + return err + } + + _, err = io.Copy(fieldWriter, r) + if err != nil { + return err + } + + return nil +} + +func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error { + if filename == "" { + return fmt.Errorf("filename cannot be empty") + } + + fieldWriter, err := fb.writer.CreateFormFile(fieldname, filename) + if err != nil { + return err + } + + _, err = io.Copy(fieldWriter, r) + if err != nil { + return err + } + + return nil +} + +func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error { + if fieldname == "" { + return fmt.Errorf("fieldname cannot be empty") + } + return fb.writer.WriteField(fieldname, value) +} + +func (fb *DefaultFormBuilder) Close() error { + return fb.writer.Close() +} + +func (fb *DefaultFormBuilder) FormDataContentType() string { + return fb.writer.FormDataContentType() +} diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go new file mode 100644 index 000000000..53ef11d23 --- /dev/null +++ b/internal/form_builder_test.go @@ -0,0 +1,190 @@ +package openai //nolint:testpackage // testing private field + +import ( + "errors" + "io" + + "github.com/sashabaranov/go-openai/internal/test/checks" + + "bytes" + "os" + "strings" + "testing" +) + +type mockFormBuilder struct { + mockCreateFormFile func(string, *os.File) error + mockWriteField func(string, string) error + mockClose func() error +} + +func (m *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error { + return m.mockCreateFormFile(fieldname, file) +} + +func (m *mockFormBuilder) WriteField(fieldname, value string) error { + return m.mockWriteField(fieldname, value) +} + +func (m *mockFormBuilder) Close() error { + return m.mockClose() +} + +func (m *mockFormBuilder) FormDataContentType() string { + return "" +} + +func TestCloseMethod(t *testing.T) { + t.Run("NormalClose", func(t *testing.T) { + body := &bytes.Buffer{} + builder := NewFormBuilder(body) + checks.NoError(t, builder.Close(), "正常关闭应成功") + }) + + t.Run("ErrorPropagation", func(t *testing.T) { + errorMock := errors.New("mock close error") + mockBuilder := &mockFormBuilder{ + mockClose: func() error { + return errorMock + }, + } + err := mockBuilder.Close() + checks.ErrorIs(t, err, errorMock, "应传递关闭错误") + }) +} + +type failingWriter struct { +} + +var errMockFailingWriterError = errors.New("mock writer failed") + +func (*failingWriter) Write([]byte) (int, error) { + return 0, errMockFailingWriterError +} + +func TestFormBuilderWithFailingWriter(t *testing.T) { + file, err := os.CreateTemp(t.TempDir(), "") + if err != nil { + t.Fatalf("Error creating tmp file: %v", err) + } + defer file.Close() + + builder := NewFormBuilder(&failingWriter{}) + err = builder.CreateFormFile("file", file) + checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails") +} + +func TestFormBuilderWithClosedFile(t *testing.T) { + file, err := os.CreateTemp(t.TempDir(), "") + if err != nil { + t.Fatalf("Error creating tmp file: %v", err) + } + file.Close() + + body := &bytes.Buffer{} + builder := NewFormBuilder(body) + err = builder.CreateFormFile("file", file) + checks.HasError(t, err, "formbuilder should return error if file is closed") + checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed") +} + +type failingReader struct { +} + +var errMockFailingReaderError = errors.New("mock reader failed") + +func (*failingReader) Read([]byte) (int, error) { + return 0, errMockFailingReaderError +} + +type readerWithNameAndContentType struct { + io.Reader +} + +func (*readerWithNameAndContentType) Name() string { + return "" +} + +func (*readerWithNameAndContentType) ContentType() string { + return "image/png" +} + +func TestFormBuilderWithReader(t *testing.T) { + file, err := os.CreateTemp(t.TempDir(), "") + if err != nil { + t.Fatalf("Error creating tmp file: %v", err) + } + defer file.Close() + builder := NewFormBuilder(&failingWriter{}) + err = builder.CreateFormFileReader("file", file, file.Name()) + checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails") + + builder = NewFormBuilder(&bytes.Buffer{}) + reader := &failingReader{} + err = builder.CreateFormFileReader("file", reader, "") + checks.ErrorIs(t, err, errMockFailingReaderError, "formbuilder should return error if copy reader fails") + + successReader := &bytes.Buffer{} + err = builder.CreateFormFileReader("file", successReader, "") + checks.NoError(t, err, "formbuilder should not return error") + + rnc := &readerWithNameAndContentType{Reader: &bytes.Buffer{}} + err = builder.CreateFormFileReader("file", rnc, "") + checks.NoError(t, err, "formbuilder should not return error") +} + +func TestFormDataContentType(t *testing.T) { + t.Run("ReturnsUnderlyingWriterContentType", func(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + contentType := builder.FormDataContentType() + if contentType == "" { + t.Errorf("expected non-empty content type, got empty string") + } + }) +} + +func TestWriteField(t *testing.T) { + t.Run("EmptyFieldNameShouldReturnError", func(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.WriteField("", "some value") + checks.HasError(t, err, "fieldname is required") + }) + + t.Run("ValidFieldNameShouldSucceed", func(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.WriteField("key", "value") + checks.NoError(t, err, "should write field without error") + }) +} + +func TestCreateFormFile(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.createFormFile("file", bytes.NewBufferString("data"), "") + if err == nil { + t.Fatal("expected error for empty filename") + } + + builder = NewFormBuilder(&failingWriter{}) + err = builder.createFormFile("file", bytes.NewBufferString("data"), "name") + checks.ErrorIs(t, err, errMockFailingWriterError, "should propagate writer error") +} + +func TestCreateFormFileSuccess(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.createFormFile("file", bytes.NewBufferString("data"), "foo.txt") + checks.NoError(t, err, "createFormFile should succeed") + + if !strings.Contains(buf.String(), "filename=\"foo.txt\"") { + t.Fatalf("expected filename header, got %q", buf.String()) + } +} diff --git a/internal/marshaller.go b/internal/marshaller.go new file mode 100644 index 000000000..223a4dc1c --- /dev/null +++ b/internal/marshaller.go @@ -0,0 +1,15 @@ +package openai + +import ( + "encoding/json" +) + +type Marshaller interface { + Marshal(value any) ([]byte, error) +} + +type JSONMarshaller struct{} + +func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) { + return json.Marshal(value) +} diff --git a/internal/marshaller_test.go b/internal/marshaller_test.go new file mode 100644 index 000000000..70694faed --- /dev/null +++ b/internal/marshaller_test.go @@ -0,0 +1,34 @@ +package openai_test + +import ( + "testing" + + openai "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestJSONMarshaller_Normal(t *testing.T) { + jm := &openai.JSONMarshaller{} + data := map[string]string{"key": "value"} + + b, err := jm.Marshal(data) + checks.NoError(t, err) + if len(b) == 0 { + t.Fatal("should return non-empty bytes") + } +} + +func TestJSONMarshaller_InvalidInput(t *testing.T) { + jm := &openai.JSONMarshaller{} + _, err := jm.Marshal(make(chan int)) + checks.HasError(t, err, "should return error for unsupported type") +} + +func TestJSONMarshaller_EmptyValue(t *testing.T) { + jm := &openai.JSONMarshaller{} + b, err := jm.Marshal(nil) + checks.NoError(t, err) + if string(b) != "null" { + t.Fatalf("unexpected marshaled value: %s", string(b)) + } +} diff --git a/internal/request_builder.go b/internal/request_builder.go new file mode 100644 index 000000000..5699f6b18 --- /dev/null +++ b/internal/request_builder.go @@ -0,0 +1,52 @@ +package openai + +import ( + "bytes" + "context" + "io" + "net/http" +) + +type RequestBuilder interface { + Build(ctx context.Context, method, url string, body any, header http.Header) (*http.Request, error) +} + +type HTTPRequestBuilder struct { + marshaller Marshaller +} + +func NewRequestBuilder() *HTTPRequestBuilder { + return &HTTPRequestBuilder{ + marshaller: &JSONMarshaller{}, + } +} + +func (b *HTTPRequestBuilder) Build( + ctx context.Context, + method string, + url string, + body any, + header http.Header, +) (req *http.Request, err error) { + var bodyReader io.Reader + if body != nil { + if v, ok := body.(io.Reader); ok { + bodyReader = v + } else { + var reqBytes []byte + reqBytes, err = b.marshaller.Marshal(body) + if err != nil { + return + } + bodyReader = bytes.NewBuffer(reqBytes) + } + } + req, err = http.NewRequestWithContext(ctx, method, url, bodyReader) + if err != nil { + return + } + if header != nil { + req.Header = header + } + return +} diff --git a/internal/request_builder_test.go b/internal/request_builder_test.go new file mode 100644 index 000000000..adccb158e --- /dev/null +++ b/internal/request_builder_test.go @@ -0,0 +1,96 @@ +package openai //nolint:testpackage // testing private field + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "reflect" + "testing" +) + +var errTestMarshallerFailed = errors.New("test marshaller failed") + +type failingMarshaller struct{} + +func (*failingMarshaller) Marshal(_ any) ([]byte, error) { + return []byte{}, errTestMarshallerFailed +} + +func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) { + builder := HTTPRequestBuilder{ + marshaller: &failingMarshaller{}, + } + + _, err := builder.Build(context.Background(), "", "", struct{}{}, nil) + if !errors.Is(err, errTestMarshallerFailed) { + t.Fatalf("Did not return error when marshaller failed: %v", err) + } +} + +func TestRequestBuilderReturnsRequest(t *testing.T) { + b := NewRequestBuilder() + var ( + ctx = context.Background() + method = http.MethodPost + url = "/foo" + request = map[string]string{"foo": "bar"} + reqBytes, _ = b.marshaller.Marshal(request) + want, _ = http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(reqBytes)) + ) + got, _ := b.Build(ctx, method, url, request, nil) + if !reflect.DeepEqual(got.Body, want.Body) || + !reflect.DeepEqual(got.URL, want.URL) || + !reflect.DeepEqual(got.Method, want.Method) { + t.Errorf("Build() got = %v, want %v", got, want) + } +} + +func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) { + var ( + ctx = context.Background() + method = http.MethodGet + url = "/foo" + want, _ = http.NewRequestWithContext(ctx, method, url, nil) + ) + b := NewRequestBuilder() + got, _ := b.Build(ctx, method, url, nil, nil) + if !reflect.DeepEqual(got, want) { + t.Errorf("Build() got = %v, want %v", got, want) + } +} + +func TestRequestBuilderWithReaderBodyAndHeader(t *testing.T) { + b := NewRequestBuilder() + ctx := context.Background() + method := http.MethodPost + url := "/reader" + bodyContent := "hello" + body := bytes.NewBufferString(bodyContent) + header := http.Header{"X-Test": []string{"val"}} + + req, err := b.Build(ctx, method, url, body, header) + if err != nil { + t.Fatalf("Build returned error: %v", err) + } + + gotBody, err := io.ReadAll(req.Body) + if err != nil { + t.Fatalf("cannot read body: %v", err) + } + if string(gotBody) != bodyContent { + t.Fatalf("expected body %q, got %q", bodyContent, string(gotBody)) + } + if req.Header.Get("X-Test") != "val" { + t.Fatalf("expected header set to val, got %q", req.Header.Get("X-Test")) + } +} + +func TestRequestBuilderInvalidURL(t *testing.T) { + b := NewRequestBuilder() + _, err := b.Build(context.Background(), http.MethodGet, ":", nil, nil) + if err == nil { + t.Fatal("expected error for invalid URL") + } +} diff --git a/internal/test/checks/checks.go b/internal/test/checks/checks.go index 713369157..6bd0964c6 100644 --- a/internal/test/checks/checks.go +++ b/internal/test/checks/checks.go @@ -12,6 +12,13 @@ func NoError(t *testing.T, err error, message ...string) { } } +func NoErrorF(t *testing.T, err error, message ...string) { + t.Helper() + if err != nil { + t.Fatal(err, message) + } +} + func HasError(t *testing.T, err error, message ...string) { t.Helper() if err == nil { diff --git a/internal/test/checks/checks_test.go b/internal/test/checks/checks_test.go new file mode 100644 index 000000000..0677054df --- /dev/null +++ b/internal/test/checks/checks_test.go @@ -0,0 +1,19 @@ +package checks_test + +import ( + "errors" + "testing" + + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestChecksSuccessPaths(t *testing.T) { + checks.NoError(t, nil) + checks.NoErrorF(t, nil) + checks.HasError(t, errors.New("err")) + target := errors.New("x") + checks.ErrorIs(t, target, target) + checks.ErrorIsF(t, target, target, "msg") + checks.ErrorIsNot(t, errors.New("y"), target) + checks.ErrorIsNotf(t, errors.New("y"), target, "msg") +} diff --git a/internal/test/failer.go b/internal/test/failer.go new file mode 100644 index 000000000..10ca64e34 --- /dev/null +++ b/internal/test/failer.go @@ -0,0 +1,21 @@ +package test + +import "errors" + +var ( + ErrTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed") +) + +type FailingErrorBuffer struct{} + +func (b *FailingErrorBuffer) Write(_ []byte) (n int, err error) { + return 0, ErrTestErrorAccumulatorWriteFailed +} + +func (b *FailingErrorBuffer) Len() int { + return 0 +} + +func (b *FailingErrorBuffer) Bytes() []byte { + return []byte{} +} diff --git a/internal/test/failer_test.go b/internal/test/failer_test.go new file mode 100644 index 000000000..fb1f4bf06 --- /dev/null +++ b/internal/test/failer_test.go @@ -0,0 +1,24 @@ +//nolint:testpackage // need access to unexported fields and types for testing +package test + +import ( + "errors" + "testing" +) + +func TestFailingErrorBuffer(t *testing.T) { + buf := &FailingErrorBuffer{} + n, err := buf.Write([]byte("test")) + if !errors.Is(err, ErrTestErrorAccumulatorWriteFailed) { + t.Fatalf("expected %v, got %v", ErrTestErrorAccumulatorWriteFailed, err) + } + if n != 0 { + t.Fatalf("expected n=0, got %d", n) + } + if buf.Len() != 0 { + t.Fatalf("expected Len 0, got %d", buf.Len()) + } + if len(buf.Bytes()) != 0 { + t.Fatalf("expected empty bytes") + } +} diff --git a/internal/test/helpers.go b/internal/test/helpers.go index 8461e5374..dc5fa6646 100644 --- a/internal/test/helpers.go +++ b/internal/test/helpers.go @@ -3,6 +3,7 @@ package test import ( "github.com/sashabaranov/go-openai/internal/test/checks" + "net/http" "os" "testing" ) @@ -18,12 +19,25 @@ func CreateTestFile(t *testing.T, path string) { file.Close() } -// CreateTestDirectory creates a temporary folder which will be deleted when cleanup is called. -func CreateTestDirectory(t *testing.T) (path string, cleanup func()) { - t.Helper() - - path, err := os.MkdirTemp(os.TempDir(), "") - checks.NoError(t, err) +// TokenRoundTripper is a struct that implements the RoundTripper +// interface, specifically to handle the authentication token by adding a token +// to the request header. We need this because the API requires that each +// request include a valid API token in the headers for authentication and +// authorization. +type TokenRoundTripper struct { + Token string + Fallback http.RoundTripper +} - return path, func() { os.RemoveAll(path) } +// RoundTrip takes an *http.Request as input and returns an +// *http.Response and an error. +// +// It is expected to use the provided request to create a connection to an HTTP +// server and return the response, or an error if one occurred. The returned +// Response should have its Body closed. If the RoundTrip method returns an +// error, the Client's Get, Head, Post, and PostForm methods return the same +// error. +func (t *TokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("Authorization", "Bearer "+t.Token) + return t.Fallback.RoundTrip(req) } diff --git a/internal/test/helpers_test.go b/internal/test/helpers_test.go new file mode 100644 index 000000000..aa177679b --- /dev/null +++ b/internal/test/helpers_test.go @@ -0,0 +1,54 @@ +package test_test + +import ( + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + internaltest "github.com/sashabaranov/go-openai/internal/test" +) + +func TestCreateTestFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "file.txt") + internaltest.CreateTestFile(t, path) + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("failed to read created file: %v", err) + } + if string(data) != "hello" { + t.Fatalf("unexpected file contents: %q", string(data)) + } +} + +func TestTokenRoundTripperAddsHeader(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer "+internaltest.GetTestToken() { + t.Fatalf("authorization header not set") + } + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := srv.Client() + client.Transport = &internaltest.TokenRoundTripper{Token: internaltest.GetTestToken(), Fallback: client.Transport} + + req, err := http.NewRequest(http.MethodGet, srv.URL, nil) + if err != nil { + t.Fatalf("request error: %v", err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatalf("client request error: %v", err) + } + if _, err = io.Copy(io.Discard, resp.Body); err != nil { + t.Fatalf("read body: %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } +} diff --git a/internal/test/server.go b/internal/test/server.go index 79d55c405..d32c3e4cb 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -4,6 +4,8 @@ import ( "log" "net/http" "net/http/httptest" + "regexp" + "strings" ) const testAPI = "this-is-my-secure-token-do-not-steal!!" @@ -21,14 +23,29 @@ func NewTestServer() *ServerTest { return &ServerTest{handlers: make(map[string]handler)} } +// HandlerCount returns the number of registered handlers. +func (ts *ServerTest) HandlerCount() int { + return len(ts.handlers) +} + +// HasHandler checks if a handler was registered for the given path. +func (ts *ServerTest) HasHandler(path string) bool { + path = strings.ReplaceAll(path, "*", ".*") + _, ok := ts.handlers[path] + return ok +} + func (ts *ServerTest) RegisterHandler(path string, handler handler) { + // to make the registered paths friendlier to a regex match in the route handler + // in OpenAITestServer + path = strings.ReplaceAll(path, "*", ".*") ts.handlers[path] = handler } // OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing. func (ts *ServerTest) OpenAITestServer() *httptest.Server { return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Printf("received request at path %q\n", r.URL.Path) + log.Printf("received a %s request at path %q\n", r.Method, r.URL.Path) // check auth if r.Header.Get("Authorization") != "Bearer "+GetTestToken() && r.Header.Get("api-key") != GetTestToken() { @@ -36,11 +53,16 @@ func (ts *ServerTest) OpenAITestServer() *httptest.Server { return } - handlerCall, ok := ts.handlers[r.URL.Path] - if !ok { - http.Error(w, "the resource path doesn't exist", http.StatusNotFound) - return + // Handle /path/* routes. + // Note: the * is converted to a .* in register handler for proper regex handling + for route, handler := range ts.handlers { + // Adding ^ and $ to make path matching deterministic since go map iteration isn't ordered + pattern, _ := regexp.Compile("^" + route + "$") + if pattern.MatchString(r.URL.Path) { + handler(w, r) + return + } } - handlerCall(w, r) + http.Error(w, "the resource path doesn't exist", http.StatusNotFound) })) } diff --git a/internal/test/server_test.go b/internal/test/server_test.go new file mode 100644 index 000000000..f8ce731d1 --- /dev/null +++ b/internal/test/server_test.go @@ -0,0 +1,80 @@ +package test_test + +import ( + "io" + "net/http" + "testing" + + internaltest "github.com/sashabaranov/go-openai/internal/test" +) + +func TestGetTestToken(t *testing.T) { + if internaltest.GetTestToken() != "this-is-my-secure-token-do-not-steal!!" { + t.Fatalf("unexpected token") + } +} + +func TestNewTestServer(t *testing.T) { + ts := internaltest.NewTestServer() + if ts == nil { + t.Fatalf("server not properly initialized") + } + if ts.HandlerCount() != 0 { + t.Fatalf("expected no handlers initially") + } +} + +func TestRegisterHandlerTransformsPath(t *testing.T) { + ts := internaltest.NewTestServer() + h := func(_ http.ResponseWriter, _ *http.Request) {} + ts.RegisterHandler("/foo/*", h) + if !ts.HasHandler("/foo/*") { + t.Fatalf("handler not registered with transformed path") + } +} + +func TestOpenAITestServer(t *testing.T) { + ts := internaltest.NewTestServer() + ts.RegisterHandler("/v1/test/*", func(w http.ResponseWriter, _ *http.Request) { + if _, err := io.WriteString(w, "ok"); err != nil { + t.Fatalf("write: %v", err) + } + }) + srv := ts.OpenAITestServer() + srv.Start() + defer srv.Close() + + base := srv.Client().Transport + client := &http.Client{Transport: &internaltest.TokenRoundTripper{Token: internaltest.GetTestToken(), Fallback: base}} + resp, err := client.Get(srv.URL + "/v1/test/123") + if err != nil { + t.Fatalf("request error: %v", err) + } + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + t.Fatalf("read response body: %v", err) + } + if resp.StatusCode != http.StatusOK || string(body) != "ok" { + t.Fatalf("unexpected response: %d %q", resp.StatusCode, string(body)) + } + + // unregistered path + resp, err = client.Get(srv.URL + "/unknown") + if err != nil { + t.Fatalf("request error: %v", err) + } + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("expected 404, got %d", resp.StatusCode) + } + + // missing token should return unauthorized + clientNoToken := srv.Client() + resp, err = clientNoToken.Get(srv.URL + "/v1/test/123") + if err != nil { + t.Fatalf("request error: %v", err) + } + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", resp.StatusCode) + } +} diff --git a/internal/unmarshaler.go b/internal/unmarshaler.go new file mode 100644 index 000000000..882876022 --- /dev/null +++ b/internal/unmarshaler.go @@ -0,0 +1,15 @@ +package openai + +import ( + "encoding/json" +) + +type Unmarshaler interface { + Unmarshal(data []byte, v any) error +} + +type JSONUnmarshaler struct{} + +func (jm *JSONUnmarshaler) Unmarshal(data []byte, v any) error { + return json.Unmarshal(data, v) +} diff --git a/internal/unmarshaler_test.go b/internal/unmarshaler_test.go new file mode 100644 index 000000000..d63eac779 --- /dev/null +++ b/internal/unmarshaler_test.go @@ -0,0 +1,37 @@ +package openai_test + +import ( + "testing" + + openai "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestJSONUnmarshaler_Normal(t *testing.T) { + jm := &openai.JSONUnmarshaler{} + data := []byte(`{"key":"value"}`) + var v map[string]string + + err := jm.Unmarshal(data, &v) + checks.NoError(t, err) + if v["key"] != "value" { + t.Fatal("unmarshal result mismatch") + } +} + +func TestJSONUnmarshaler_InvalidJSON(t *testing.T) { + jm := &openai.JSONUnmarshaler{} + data := []byte(`{invalid}`) + var v map[string]interface{} + + err := jm.Unmarshal(data, &v) + checks.HasError(t, err, "should return error for invalid JSON") +} + +func TestJSONUnmarshaler_EmptyInput(t *testing.T) { + jm := &openai.JSONUnmarshaler{} + var v interface{} + + err := jm.Unmarshal(nil, &v) + checks.HasError(t, err, "should return error for nil input") +} diff --git a/jsonschema/containsref_test.go b/jsonschema/containsref_test.go new file mode 100644 index 000000000..dc1842775 --- /dev/null +++ b/jsonschema/containsref_test.go @@ -0,0 +1,48 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +// SelfRef struct used to produce a self-referential schema. +type SelfRef struct { + Friends []SelfRef `json:"friends"` +} + +// Address struct referenced by Person without self-reference. +type Address struct { + Street string `json:"street"` +} + +type Person struct { + Address Address `json:"address"` +} + +// TestGenerateSchemaForType_SelfRef ensures that self-referential types are not +// flattened during schema generation. +func TestGenerateSchemaForType_SelfRef(t *testing.T) { + schema, err := jsonschema.GenerateSchemaForType(SelfRef{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := schema.Defs["SelfRef"]; !ok { + t.Fatal("expected defs to contain SelfRef for self reference") + } +} + +// TestGenerateSchemaForType_NoSelfRef ensures that non-self-referential types +// are flattened and do not reappear in $defs. +func TestGenerateSchemaForType_NoSelfRef(t *testing.T) { + schema, err := jsonschema.GenerateSchemaForType(Person{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := schema.Defs["Person"]; ok { + t.Fatal("unexpected Person definition in defs") + } + if _, ok := schema.Defs["Address"]; !ok { + t.Fatal("expected Address definition in defs") + } +} diff --git a/jsonschema/json.go b/jsonschema/json.go new file mode 100644 index 000000000..75e3b5173 --- /dev/null +++ b/jsonschema/json.go @@ -0,0 +1,235 @@ +// Package jsonschema provides very simple functionality for representing a JSON schema as a +// (nested) struct. This struct can be used with the chat completion "function call" feature. +// For more complicated schemas, it is recommended to use a dedicated JSON schema library +// and/or pass in the schema in []byte format. +package jsonschema + +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" + "strings" +) + +type DataType string + +const ( + Object DataType = "object" + Number DataType = "number" + Integer DataType = "integer" + String DataType = "string" + Array DataType = "array" + Null DataType = "null" + Boolean DataType = "boolean" +) + +// Definition is a struct for describing a JSON Schema. +// It is fairly limited, and you may have better luck using a third-party library. +type Definition struct { + // Type specifies the data type of the schema. + Type DataType `json:"type,omitempty"` + // Description is the description of the schema. + Description string `json:"description,omitempty"` + // Enum is used to restrict a value to a fixed set of values. It must be an array with at least + // one element, where each element is unique. You will probably only use this with strings. + Enum []string `json:"enum,omitempty"` + // Properties describes the properties of an object, if the schema type is Object. + Properties map[string]Definition `json:"properties,omitempty"` + // Required specifies which properties are required, if the schema type is Object. + Required []string `json:"required,omitempty"` + // Items specifies which data type an array contains, if the schema type is Array. + Items *Definition `json:"items,omitempty"` + // AdditionalProperties is used to control the handling of properties in an object + // that are not explicitly defined in the properties section of the schema. example: + // additionalProperties: true + // additionalProperties: false + // additionalProperties: jsonschema.Definition{Type: jsonschema.String} + AdditionalProperties any `json:"additionalProperties,omitempty"` + // Whether the schema is nullable or not. + Nullable bool `json:"nullable,omitempty"` + + // Ref Reference to a definition in $defs or external schema. + Ref string `json:"$ref,omitempty"` + // Defs A map of reusable schema definitions. + Defs map[string]Definition `json:"$defs,omitempty"` +} + +func (d *Definition) MarshalJSON() ([]byte, error) { + if d.Properties == nil { + d.Properties = make(map[string]Definition) + } + type Alias Definition + return json.Marshal(struct { + Alias + }{ + Alias: (Alias)(*d), + }) +} + +func (d *Definition) Unmarshal(content string, v any) error { + return VerifySchemaAndUnmarshal(*d, []byte(content), v) +} + +func GenerateSchemaForType(v any) (*Definition, error) { + var defs = make(map[string]Definition) + def, err := reflectSchema(reflect.TypeOf(v), defs) + if err != nil { + return nil, err + } + // If the schema has a root $ref, resolve it by: + // 1. Extracting the key from the $ref. + // 2. Detaching the referenced definition from $defs. + // 3. Checking for self-references in the detached definition. + // - If a self-reference is found, restore the original $defs structure. + // 4. Flattening the referenced definition into the root schema. + // 5. Clearing the $ref field in the root schema. + if def.Ref != "" { + origRef := def.Ref + key := strings.TrimPrefix(origRef, "#/$defs/") + if root, ok := defs[key]; ok { + delete(defs, key) + root.Defs = defs + if containsRef(root, origRef) { + root.Defs = nil + defs[key] = root + } + *def = root + } + def.Ref = "" + } + def.Defs = defs + return def, nil +} + +func reflectSchema(t reflect.Type, defs map[string]Definition) (*Definition, error) { + var d Definition + switch t.Kind() { + case reflect.String: + d.Type = String + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + d.Type = Integer + case reflect.Float32, reflect.Float64: + d.Type = Number + case reflect.Bool: + d.Type = Boolean + case reflect.Slice, reflect.Array: + d.Type = Array + items, err := reflectSchema(t.Elem(), defs) + if err != nil { + return nil, err + } + d.Items = items + case reflect.Struct: + if t.Name() != "" { + if _, ok := defs[t.Name()]; !ok { + defs[t.Name()] = Definition{} + object, err := reflectSchemaObject(t, defs) + if err != nil { + return nil, err + } + defs[t.Name()] = *object + } + return &Definition{Ref: "#/$defs/" + t.Name()}, nil + } + d.Type = Object + d.AdditionalProperties = false + object, err := reflectSchemaObject(t, defs) + if err != nil { + return nil, err + } + d = *object + case reflect.Ptr: + definition, err := reflectSchema(t.Elem(), defs) + if err != nil { + return nil, err + } + d = *definition + case reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128, + reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, + reflect.UnsafePointer: + return nil, fmt.Errorf("unsupported type: %s", t.Kind().String()) + default: + } + return &d, nil +} + +func reflectSchemaObject(t reflect.Type, defs map[string]Definition) (*Definition, error) { + var d = Definition{ + Type: Object, + AdditionalProperties: false, + } + properties := make(map[string]Definition) + var requiredFields []string + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if !field.IsExported() { + continue + } + jsonTag := field.Tag.Get("json") + var required = true + switch { + case jsonTag == "-": + continue + case jsonTag == "": + jsonTag = field.Name + case strings.HasSuffix(jsonTag, ",omitempty"): + jsonTag = strings.TrimSuffix(jsonTag, ",omitempty") + required = false + } + + item, err := reflectSchema(field.Type, defs) + if err != nil { + return nil, err + } + description := field.Tag.Get("description") + if description != "" { + item.Description = description + } + enum := field.Tag.Get("enum") + if enum != "" { + item.Enum = strings.Split(enum, ",") + } + + if n := field.Tag.Get("nullable"); n != "" { + nullable, _ := strconv.ParseBool(n) + item.Nullable = nullable + } + + properties[jsonTag] = *item + + if s := field.Tag.Get("required"); s != "" { + required, _ = strconv.ParseBool(s) + } + if required { + requiredFields = append(requiredFields, jsonTag) + } + } + d.Required = requiredFields + d.Properties = properties + return &d, nil +} + +func containsRef(def Definition, targetRef string) bool { + if def.Ref == targetRef { + return true + } + + for _, d := range def.Defs { + if containsRef(d, targetRef) { + return true + } + } + + for _, prop := range def.Properties { + if containsRef(prop, targetRef) { + return true + } + } + + if def.Items != nil && containsRef(*def.Items, targetRef) { + return true + } + return false +} diff --git a/jsonschema/json_additional_test.go b/jsonschema/json_additional_test.go new file mode 100644 index 000000000..70cf37490 --- /dev/null +++ b/jsonschema/json_additional_test.go @@ -0,0 +1,73 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +// Test Definition.Unmarshal, including success path, validation error, +// JSON syntax error and type mismatch during unmarshalling. +func TestDefinitionUnmarshal(t *testing.T) { + schema := jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + }, + } + + var dst struct { + Name string `json:"name"` + } + if err := schema.Unmarshal(`{"name":"foo"}`, &dst); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dst.Name != "foo" { + t.Errorf("expected name to be foo, got %q", dst.Name) + } + + if err := schema.Unmarshal(`{`, &dst); err == nil { + t.Error("expected error for malformed json") + } + + if err := schema.Unmarshal(`{"name":1}`, &dst); err == nil { + t.Error("expected validation error") + } + + numSchema := jsonschema.Definition{Type: jsonschema.Number} + var s string + if err := numSchema.Unmarshal(`123`, &s); err == nil { + t.Error("expected unmarshal type error") + } +} + +// Ensure GenerateSchemaForType returns an error when encountering unsupported types. +func TestGenerateSchemaForTypeUnsupported(t *testing.T) { + type Bad struct { + Ch chan int `json:"ch"` + } + if _, err := jsonschema.GenerateSchemaForType(Bad{}); err == nil { + t.Fatal("expected error for unsupported type") + } +} + +// Validate should fail when provided data does not match the expected container types. +func TestValidateInvalidContainers(t *testing.T) { + objSchema := jsonschema.Definition{Type: jsonschema.Object} + if jsonschema.Validate(objSchema, 1) { + t.Error("expected object validation to fail for non-map input") + } + + arrSchema := jsonschema.Definition{Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}} + if jsonschema.Validate(arrSchema, 1) { + t.Error("expected array validation to fail for non-slice input") + } +} + +// Validate should return false when $ref cannot be resolved. +func TestValidateRefNotFound(t *testing.T) { + refSchema := jsonschema.Definition{Ref: "#/$defs/Missing"} + if jsonschema.Validate(refSchema, "data", jsonschema.WithDefs(map[string]jsonschema.Definition{})) { + t.Error("expected validation to fail when reference is missing") + } +} diff --git a/jsonschema/json_errors_test.go b/jsonschema/json_errors_test.go new file mode 100644 index 000000000..3b864fc21 --- /dev/null +++ b/jsonschema/json_errors_test.go @@ -0,0 +1,27 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +// TestGenerateSchemaForType_ErrorPaths verifies error handling for unsupported types. +func TestGenerateSchemaForType_ErrorPaths(t *testing.T) { + type anon struct{ Ch chan int } + tests := []struct { + name string + v any + }{ + {"slice", []chan int{}}, + {"anon struct", anon{}}, + {"pointer", (*chan int)(nil)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if _, err := jsonschema.GenerateSchemaForType(tt.v); err == nil { + t.Errorf("expected error for %s", tt.name) + } + }) + } +} diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go new file mode 100644 index 000000000..34f5d88eb --- /dev/null +++ b/jsonschema/json_test.go @@ -0,0 +1,670 @@ +package jsonschema_test + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +func TestDefinition_MarshalJSON(t *testing.T) { + tests := []struct { + name string + def jsonschema.Definition + want string + }{ + { + name: "Test with empty Definition", + def: jsonschema.Definition{}, + want: `{}`, + }, + { + name: "Test with Definition properties set", + def: jsonschema.Definition{ + Type: jsonschema.String, + Description: "A string type", + Properties: map[string]jsonschema.Definition{ + "name": { + Type: jsonschema.String, + }, + }, + }, + want: `{ + "type":"string", + "description":"A string type", + "properties":{ + "name":{ + "type":"string" + } + } + }`, + }, + { + name: "Test with nested Definition properties", + def: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "user": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": { + Type: jsonschema.String, + }, + "age": { + Type: jsonschema.Integer, + }, + }, + }, + }, + }, + want: `{ + "type":"object", + "properties":{ + "user":{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + }, + "age":{ + "type":"integer" + } + } + } + } + }`, + }, + { + name: "Test with complex nested Definition", + def: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "user": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": { + Type: jsonschema.String, + }, + "age": { + Type: jsonschema.Integer, + }, + "address": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "city": { + Type: jsonschema.String, + }, + "country": { + Type: jsonschema.String, + }, + }, + }, + }, + }, + }, + }, + want: `{ + "type":"object", + "properties":{ + "user":{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + }, + "age":{ + "type":"integer" + }, + "address":{ + "type":"object", + "properties":{ + "city":{ + "type":"string" + }, + "country":{ + "type":"string" + } + } + } + } + } + } + }`, + }, + { + name: "Test with Array type Definition", + def: jsonschema.Definition{ + Type: jsonschema.Array, + Items: &jsonschema.Definition{ + Type: jsonschema.String, + }, + Properties: map[string]jsonschema.Definition{ + "name": { + Type: jsonschema.String, + }, + }, + }, + want: `{ + "type":"array", + "items":{ + "type":"string" + }, + "properties":{ + "name":{ + "type":"string" + } + } + }`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wantBytes := []byte(tt.want) + var want map[string]interface{} + err := json.Unmarshal(wantBytes, &want) + if err != nil { + t.Errorf("Failed to Unmarshal JSON: error = %v", err) + return + } + + got := structToMap(t, tt.def) + gotPtr := structToMap(t, &tt.def) + + if !reflect.DeepEqual(got, want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, want) + } + if !reflect.DeepEqual(gotPtr, want) { + t.Errorf("MarshalJSON() gotPtr = %v, want %v", gotPtr, want) + } + }) + } +} + +type User struct { + ID int `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Orders []*Order `json:"orders,omitempty"` +} + +type Order struct { + ID int `json:"id,omitempty"` + Amount float64 `json:"amount,omitempty"` + Buyer *User `json:"buyer,omitempty"` +} + +func TestStructToSchema(t *testing.T) { + type Tweet struct { + Text string `json:"text"` + } + + type Person struct { + Name string `json:"name,omitempty"` + Age int `json:"age,omitempty"` + Friends []Person `json:"friends,omitempty"` + Tweets []Tweet `json:"tweets,omitempty"` + } + + type MyStructuredResponse struct { + PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"` + CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` + KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"` + SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` + } + + tests := []struct { + name string + in any + want string + }{ + { + name: "Test with empty struct", + in: struct{}{}, + want: `{ + "type":"object", + "additionalProperties":false + }`, + }, + { + name: "Test with struct containing many fields", + in: struct { + Name string `json:"name"` + Age int `json:"age"` + Active bool `json:"active"` + Height float64 `json:"height"` + Cities []struct { + Name string `json:"name"` + State string `json:"state"` + } `json:"cities"` + }{ + Name: "John Doe", + Age: 30, + Cities: []struct { + Name string `json:"name"` + State string `json:"state"` + }{ + {Name: "New York", State: "NY"}, + {Name: "Los Angeles", State: "CA"}, + }, + }, + want: `{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + }, + "age":{ + "type":"integer" + }, + "active":{ + "type":"boolean" + }, + "height":{ + "type":"number" + }, + "cities":{ + "type":"array", + "items":{ + "additionalProperties":false, + "type":"object", + "properties":{ + "name":{ + "type":"string" + }, + "state":{ + "type":"string" + } + }, + "required":["name","state"] + } + } + }, + "required":["name","age","active","height","cities"], + "additionalProperties":false + }`, + }, + { + name: "Test with description tag", + in: struct { + Name string `json:"name" description:"The name of the person"` + }{ + Name: "John Doe", + }, + want: `{ + "type":"object", + "properties":{ + "name":{ + "type":"string", + "description":"The name of the person" + } + }, + "required":["name"], + "additionalProperties":false + }`, + }, + { + name: "Test with required tag", + in: struct { + Name string `json:"name" required:"false"` + }{ + Name: "John Doe", + }, + want: `{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + } + }, + "additionalProperties":false + }`, + }, + { + name: "Test with enum tag", + in: struct { + Color string `json:"color" enum:"red,green,blue"` + }{ + Color: "red", + }, + want: `{ + "type":"object", + "properties":{ + "color":{ + "type":"string", + "enum":["red","green","blue"] + } + }, + "required":["color"], + "additionalProperties":false + }`, + }, + { + name: "Test with nullable tag", + in: struct { + Name *string `json:"name" nullable:"true"` + }{ + Name: nil, + }, + want: `{ + + "type":"object", + "properties":{ + "name":{ + "type":"string", + "nullable":true + } + }, + "required":["name"], + "additionalProperties":false + }`, + }, + { + name: "Test with exclude mark", + in: struct { + Name string `json:"-"` + }{ + Name: "Name", + }, + want: `{ + "type":"object", + "additionalProperties":false + }`, + }, + { + name: "Test with no json tag", + in: struct { + Name string + }{ + Name: "", + }, + want: `{ + "type":"object", + "properties":{ + "Name":{ + "type":"string" + } + }, + "required":["Name"], + "additionalProperties":false + }`, + }, + { + name: "Test with omitempty tag", + in: struct { + Name string `json:"name,omitempty"` + }{ + Name: "", + }, + want: `{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + } + }, + "additionalProperties":false + }`, + }, + { + name: "Test with $ref and $defs", + in: struct { + Person Person `json:"person"` + Tweets []Tweet `json:"tweets"` + }{}, + want: `{ + "type" : "object", + "properties" : { + "person" : { + "$ref" : "#/$defs/Person" + }, + "tweets" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Tweet" + } + } + }, + "required" : [ "person", "tweets" ], + "additionalProperties" : false, + "$defs" : { + "Person" : { + "type" : "object", + "properties" : { + "age" : { + "type" : "integer" + }, + "friends" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Person" + } + }, + "name" : { + "type" : "string" + }, + "tweets" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Tweet" + } + } + }, + "additionalProperties" : false + }, + "Tweet" : { + "type" : "object", + "properties" : { + "text" : { + "type" : "string" + } + }, + "required" : [ "text" ], + "additionalProperties" : false + } + } +}`, + }, + { + name: "Test Person", + in: Person{}, + want: `{ + "type": "object", + "properties": { + "age": { + "type": "integer" + }, + "friends": { + "type": "array", + "items": { + "$ref": "#/$defs/Person" + } + }, + "name": { + "type": "string" + }, + "tweets": { + "type": "array", + "items": { + "$ref": "#/$defs/Tweet" + } + } + }, + "additionalProperties": false, + "$defs": { + "Person": { + "type": "object", + "properties": { + "age": { + "type": "integer" + }, + "friends": { + "type": "array", + "items": { + "$ref": "#/$defs/Person" + } + }, + "name": { + "type": "string" + }, + "tweets": { + "type": "array", + "items": { + "$ref": "#/$defs/Tweet" + } + } + }, + "additionalProperties": false + }, + "Tweet": { + "type": "object", + "properties": { + "text": { + "type": "string" + } + }, + "required": [ + "text" + ], + "additionalProperties": false + } + } +}`, + }, + { + name: "Test MyStructuredResponse", + in: MyStructuredResponse{}, + want: `{ + "type": "object", + "properties": { + "camel_case": { + "type": "string", + "description": "CamelCase" + }, + "kebab_case": { + "type": "string", + "description": "KebabCase" + }, + "pascal_case": { + "type": "string", + "description": "PascalCase" + }, + "snake_case": { + "type": "string", + "description": "SnakeCase" + } + }, + "required": [ + "pascal_case", + "camel_case", + "kebab_case", + "snake_case" + ], + "additionalProperties": false +}`, + }, + { + name: "Test User", + in: User{}, + want: `{ + "type": "object", + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string" + }, + "orders": { + "type": "array", + "items": { + "$ref": "#/$defs/Order" + } + } + }, + "additionalProperties": false, + "$defs": { + "Order": { + "type": "object", + "properties": { + "amount": { + "type": "number" + }, + "buyer": { + "$ref": "#/$defs/User" + }, + "id": { + "type": "integer" + } + }, + "additionalProperties": false + }, + "User": { + "type": "object", + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string" + }, + "orders": { + "type": "array", + "items": { + "$ref": "#/$defs/Order" + } + } + }, + "additionalProperties": false + } + } +}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wantBytes := []byte(tt.want) + + schema, err := jsonschema.GenerateSchemaForType(tt.in) + if err != nil { + t.Errorf("Failed to generate schema: error = %v", err) + return + } + + var want map[string]interface{} + err = json.Unmarshal(wantBytes, &want) + if err != nil { + t.Errorf("Failed to Unmarshal JSON: error = %v", err) + return + } + + got := structToMap(t, schema) + gotPtr := structToMap(t, &schema) + + if !reflect.DeepEqual(got, want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, want) + } + if !reflect.DeepEqual(gotPtr, want) { + t.Errorf("MarshalJSON() gotPtr = %v, want %v", gotPtr, want) + } + }) + } +} + +func structToMap(t *testing.T, v any) map[string]any { + t.Helper() + gotBytes, err := json.Marshal(v) + if err != nil { + t.Errorf("Failed to Marshal JSON: error = %v", err) + return nil + } + + var got map[string]interface{} + err = json.Unmarshal(gotBytes, &got) + if err != nil { + t.Errorf("Failed to Unmarshal JSON: error = %v", err) + return nil + } + return got +} diff --git a/jsonschema/validate.go b/jsonschema/validate.go new file mode 100644 index 000000000..1bd2f809c --- /dev/null +++ b/jsonschema/validate.go @@ -0,0 +1,140 @@ +package jsonschema + +import ( + "encoding/json" + "errors" +) + +func CollectDefs(def Definition) map[string]Definition { + result := make(map[string]Definition) + collectDefsRecursive(def, result, "#") + return result +} + +func collectDefsRecursive(def Definition, result map[string]Definition, prefix string) { + for k, v := range def.Defs { + path := prefix + "/$defs/" + k + result[path] = v + collectDefsRecursive(v, result, path) + } + for k, sub := range def.Properties { + collectDefsRecursive(sub, result, prefix+"/properties/"+k) + } + if def.Items != nil { + collectDefsRecursive(*def.Items, result, prefix) + } +} + +func VerifySchemaAndUnmarshal(schema Definition, content []byte, v any) error { + var data any + err := json.Unmarshal(content, &data) + if err != nil { + return err + } + if !Validate(schema, data, WithDefs(CollectDefs(schema))) { + return errors.New("data validation failed against the provided schema") + } + return json.Unmarshal(content, &v) +} + +type validateArgs struct { + Defs map[string]Definition +} + +type ValidateOption func(*validateArgs) + +func WithDefs(defs map[string]Definition) ValidateOption { + return func(option *validateArgs) { + option.Defs = defs + } +} + +func Validate(schema Definition, data any, opts ...ValidateOption) bool { + args := validateArgs{} + for _, opt := range opts { + opt(&args) + } + if len(opts) == 0 { + args.Defs = CollectDefs(schema) + } + switch schema.Type { + case Object: + return validateObject(schema, data, args.Defs) + case Array: + return validateArray(schema, data, args.Defs) + case String: + v, ok := data.(string) + if ok && len(schema.Enum) > 0 { + return contains(schema.Enum, v) + } + return ok + case Number: // float64 and int + _, ok := data.(float64) + if !ok { + _, ok = data.(int) + } + return ok + case Boolean: + _, ok := data.(bool) + return ok + case Integer: + // Golang unmarshals all numbers as float64, so we need to check if the float64 is an integer + if num, ok := data.(float64); ok { + return num == float64(int64(num)) + } + _, ok := data.(int) + return ok + case Null: + return data == nil + default: + if schema.Ref != "" && args.Defs != nil { + if v, ok := args.Defs[schema.Ref]; ok { + return Validate(v, data, WithDefs(args.Defs)) + } + } + return false + } +} + +func validateObject(schema Definition, data any, defs map[string]Definition) bool { + dataMap, ok := data.(map[string]any) + if !ok { + return false + } + for _, field := range schema.Required { + if _, exists := dataMap[field]; !exists { + return false + } + } + for key, valueSchema := range schema.Properties { + value, exists := dataMap[key] + if exists && !Validate(valueSchema, value, WithDefs(defs)) { + return false + } else if !exists && contains(schema.Required, key) { + return false + } + } + return true +} + +func validateArray(schema Definition, data any, defs map[string]Definition) bool { + dataArray, ok := data.([]any) + if !ok { + return false + } + for _, item := range dataArray { + if !Validate(*schema.Items, item, WithDefs(defs)) { + return false + } + } + return true +} + +func contains[S ~[]E, E comparable](s S, v E) bool { + for i := range s { + if v == s[i] { + return true + } + } + return false +} diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go new file mode 100644 index 000000000..aefdf4069 --- /dev/null +++ b/jsonschema/validate_test.go @@ -0,0 +1,347 @@ +package jsonschema_test + +import ( + "reflect" + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +func Test_Validate(t *testing.T) { + type args struct { + data any + schema jsonschema.Definition + } + tests := []struct { + name string + args args + want bool + }{ + // string integer number boolean + {"", args{data: "ABC", schema: jsonschema.Definition{Type: jsonschema.String}}, true}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.String}}, false}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Integer}}, true}, + {"", args{data: 123.4, schema: jsonschema.Definition{Type: jsonschema.Integer}}, false}, + {"", args{data: "ABC", schema: jsonschema.Definition{Type: jsonschema.Number}}, false}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Number}}, true}, + {"", args{data: false, schema: jsonschema.Definition{Type: jsonschema.Boolean}}, true}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Boolean}}, false}, + {"", args{data: nil, schema: jsonschema.Definition{Type: jsonschema.Null}}, true}, + {"", args{data: 0, schema: jsonschema.Definition{Type: jsonschema.Null}}, false}, + // array + {"", args{data: []any{"a", "b", "c"}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}}, + }, true}, + {"", args{data: []any{1, 2, 3}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}}, + }, false}, + {"", args{data: []any{1, 2, 3}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Integer}}, + }, true}, + {"", args{data: []any{1, 2, 3.4}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Integer}}, + }, false}, + // object + {"", args{data: map[string]any{ + "string": "abc", + "integer": 123, + "number": 123.4, + "boolean": false, + "array": []any{1, 2, 3}, + }, schema: jsonschema.Definition{Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + "number": {Type: jsonschema.Number}, + "boolean": {Type: jsonschema.Boolean}, + "array": {Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Number}}, + }, + Required: []string{"string"}, + }}, true}, + {"", args{data: map[string]any{ + "integer": 123, + "number": 123.4, + "boolean": false, + "array": []any{1, 2, 3}, + }, schema: jsonschema.Definition{Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + "number": {Type: jsonschema.Number}, + "boolean": {Type: jsonschema.Boolean}, + "array": {Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Number}}, + }, + Required: []string{"string"}, + }}, false}, + { + "test schema with ref and defs", args{data: map[string]any{ + "person": map[string]any{ + "name": "John", + "gender": "male", + "age": 28, + "profile": map[string]any{ + "full_name": "John Doe", + }, + }, + }, schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }}, true}, + { + "test enum invalid value", args{data: map[string]any{ + "person": map[string]any{ + "name": "John", + "gender": "other", + "age": 28, + "profile": map[string]any{ + "full_name": "John Doe", + }, + }, + }, schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := jsonschema.Validate(tt.args.schema, tt.args.data); got != tt.want { + t.Errorf("Validate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnmarshal(t *testing.T) { + type args struct { + schema jsonschema.Definition + content []byte + v any + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "number": {Type: jsonschema.Number}, + }, + }, + content: []byte(`{"string":"abc","number":123.4}`), + v: &struct { + String string `json:"string"` + Number float64 `json:"number"` + }{}, + }, false}, + {"", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "number": {Type: jsonschema.Number}, + }, + Required: []string{"string", "number"}, + }, + content: []byte(`{"string":"abc"}`), + v: struct { + String string `json:"string"` + Number float64 `json:"number"` + }{}, + }, true}, + {"validate integer", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + }, + Required: []string{"string", "integer"}, + }, + content: []byte(`{"string":"abc","integer":123}`), + v: &struct { + String string `json:"string"` + Integer int `json:"integer"` + }{}, + }, false}, + {"validate integer failed", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + }, + Required: []string{"string", "integer"}, + }, + content: []byte(`{"string":"abc","integer":123.4}`), + v: &struct { + String string `json:"string"` + Integer int `json:"integer"` + }{}, + }, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := jsonschema.VerifySchemaAndUnmarshal(tt.args.schema, tt.args.content, tt.args.v) + if (err != nil) != tt.wantErr { + t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestCollectDefs(t *testing.T) { + type args struct { + schema jsonschema.Definition + } + tests := []struct { + name string + args args + want map[string]jsonschema.Definition + }{ + { + "test collect defs", + args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }, + }, + map[string]jsonschema.Definition{ + "#/$defs/Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "#/$defs/Person/$defs/Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + "#/$defs/Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := jsonschema.CollectDefs(tt.args.schema) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("CollectDefs() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/marshaller.go b/marshaller.go deleted file mode 100644 index 308ccd154..000000000 --- a/marshaller.go +++ /dev/null @@ -1,15 +0,0 @@ -package openai - -import ( - "encoding/json" -) - -type marshaller interface { - marshal(value any) ([]byte, error) -} - -type jsonMarshaller struct{} - -func (jm *jsonMarshaller) marshal(value any) ([]byte, error) { - return json.Marshal(value) -} diff --git a/messages.go b/messages.go new file mode 100644 index 000000000..3852d2e37 --- /dev/null +++ b/messages.go @@ -0,0 +1,224 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +const ( + messagesSuffix = "messages" +) + +type Message struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int `json:"created_at"` + ThreadID string `json:"thread_id"` + Role string `json:"role"` + Content []MessageContent `json:"content"` + FileIds []string `json:"file_ids"` //nolint:revive //backwards-compatibility + AssistantID *string `json:"assistant_id,omitempty"` + RunID *string `json:"run_id,omitempty"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type MessagesList struct { + Messages []Message `json:"data"` + + Object string `json:"object"` + FirstID *string `json:"first_id"` + LastID *string `json:"last_id"` + HasMore bool `json:"has_more"` + + httpHeader +} + +type MessageContent struct { + Type string `json:"type"` + Text *MessageText `json:"text,omitempty"` + ImageFile *ImageFile `json:"image_file,omitempty"` + ImageURL *ImageURL `json:"image_url,omitempty"` +} +type MessageText struct { + Value string `json:"value"` + Annotations []any `json:"annotations"` +} + +type ImageFile struct { + FileID string `json:"file_id"` +} + +type ImageURL struct { + URL string `json:"url"` + Detail string `json:"detail"` +} + +type MessageRequest struct { + Role string `json:"role"` + Content string `json:"content"` + FileIds []string `json:"file_ids,omitempty"` //nolint:revive // backwards-compatibility + Metadata map[string]any `json:"metadata,omitempty"` + Attachments []ThreadAttachment `json:"attachments,omitempty"` +} + +type MessageFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int `json:"created_at"` + MessageID string `json:"message_id"` + + httpHeader +} + +type MessageFilesList struct { + MessageFiles []MessageFile `json:"data"` + + httpHeader +} + +type MessageDeletionStatus struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + +// CreateMessage creates a new message. +func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &msg) + return +} + +// ListMessage fetches all messages in the thread. +func (c *Client) ListMessage(ctx context.Context, threadID string, + limit *int, + order *string, + after *string, + before *string, + runID *string, +) (messages MessagesList, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if order != nil { + urlValues.Add("order", *order) + } + if after != nil { + urlValues.Add("after", *after) + } + if before != nil { + urlValues.Add("before", *before) + } + if runID != nil { + urlValues.Add("run_id", *runID) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("/threads/%s/%s%s", threadID, messagesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &messages) + return +} + +// RetrieveMessage retrieves a Message. +func (c *Client) RetrieveMessage( + ctx context.Context, + threadID, messageID string, +) (msg Message, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &msg) + return +} + +// ModifyMessage modifies a message. +func (c *Client) ModifyMessage( + ctx context.Context, + threadID, messageID string, + metadata map[string]string, +) (msg Message, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(map[string]any{"metadata": metadata}), withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &msg) + return +} + +// RetrieveMessageFile fetches a message file. +func (c *Client) RetrieveMessageFile( + ctx context.Context, + threadID, messageID, fileID string, +) (file MessageFile, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files/%s", threadID, messagesSuffix, messageID, fileID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &file) + return +} + +// ListMessageFiles fetches all files attached to a message. +func (c *Client) ListMessageFiles( + ctx context.Context, + threadID, messageID string, +) (files MessageFilesList, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &files) + return +} + +// DeleteMessage deletes a message.. +func (c *Client) DeleteMessage( + ctx context.Context, + threadID, messageID string, +) (status MessageDeletionStatus, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &status) + return +} diff --git a/messages_test.go b/messages_test.go new file mode 100644 index 000000000..b25755f98 --- /dev/null +++ b/messages_test.go @@ -0,0 +1,272 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +var emptyStr = "" + +func setupServerForTestMessage(t *testing.T, server *test.ServerTest) { + threadID := "thread_abc123" + messageID := "msg_abc123" + fileID := "file_abc123" + + server.RegisterHandler( + "/v1/threads/"+threadID+"/messages/"+messageID+"/files/"+fileID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal( + openai.MessageFile{ + ID: fileID, + Object: "thread.message.file", + CreatedAt: 1699061776, + MessageID: messageID, + }) + fmt.Fprintln(w, string(resBytes)) + default: + t.Fatalf("unsupported messages http method: %s", r.Method) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/messages/"+messageID+"/files", + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal( + openai.MessageFilesList{MessageFiles: []openai.MessageFile{{ + ID: fileID, + Object: "thread.message.file", + CreatedAt: 0, + MessageID: messageID, + }}}) + fmt.Fprintln(w, string(resBytes)) + default: + t.Fatalf("unsupported messages http method: %s", r.Method) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/messages/"+messageID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + metadata := map[string]any{} + err := json.NewDecoder(r.Body).Decode(&metadata) + checks.NoError(t, err, "unable to decode metadata in modify message call") + payload, ok := metadata["metadata"].(map[string]any) + if !ok { + t.Fatalf("metadata payload improperly wrapped %+v", metadata) + } + + resBytes, _ := json.Marshal( + openai.Message{ + ID: messageID, + Object: "thread.message", + CreatedAt: 1234567890, + ThreadID: threadID, + Role: "user", + Content: []openai.MessageContent{{ + Type: "text", + Text: &openai.MessageText{ + Value: "How does AI work?", + Annotations: nil, + }, + }}, + FileIds: nil, + AssistantID: &emptyStr, + RunID: &emptyStr, + Metadata: payload, + }) + + fmt.Fprintln(w, string(resBytes)) + case http.MethodGet: + resBytes, _ := json.Marshal( + openai.Message{ + ID: messageID, + Object: "thread.message", + CreatedAt: 1234567890, + ThreadID: threadID, + Role: "user", + Content: []openai.MessageContent{{ + Type: "text", + Text: &openai.MessageText{ + Value: "How does AI work?", + Annotations: nil, + }, + }}, + FileIds: nil, + AssistantID: &emptyStr, + RunID: &emptyStr, + Metadata: nil, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + resBytes, _ := json.Marshal(openai.MessageDeletionStatus{ + ID: messageID, + Object: "thread.message.deleted", + Deleted: true, + }) + fmt.Fprintln(w, string(resBytes)) + default: + t.Fatalf("unsupported messages http method: %s", r.Method) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/messages", + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + resBytes, _ := json.Marshal(openai.Message{ + ID: messageID, + Object: "thread.message", + CreatedAt: 1234567890, + ThreadID: threadID, + Role: "user", + Content: []openai.MessageContent{{ + Type: "text", + Text: &openai.MessageText{ + Value: "How does AI work?", + Annotations: nil, + }, + }}, + FileIds: nil, + AssistantID: &emptyStr, + RunID: &emptyStr, + Metadata: nil, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodGet: + resBytes, _ := json.Marshal(openai.MessagesList{ + Object: "list", + Messages: []openai.Message{{ + ID: messageID, + Object: "thread.message", + CreatedAt: 1234567890, + ThreadID: threadID, + Role: "user", + Content: []openai.MessageContent{{ + Type: "text", + Text: &openai.MessageText{ + Value: "How does AI work?", + Annotations: nil, + }, + }}, + FileIds: nil, + AssistantID: &emptyStr, + RunID: &emptyStr, + Metadata: nil, + }}, + FirstID: &messageID, + LastID: &messageID, + HasMore: false, + }) + fmt.Fprintln(w, string(resBytes)) + default: + t.Fatalf("unsupported messages http method: %s", r.Method) + } + }, + ) +} + +// TestMessages Tests the messages endpoint of the API using the mocked server. +func TestMessages(t *testing.T) { + threadID := "thread_abc123" + messageID := "msg_abc123" + fileID := "file_abc123" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + setupServerForTestMessage(t, server) + ctx := context.Background() + + // static assertion of return type + var msg openai.Message + msg, err := client.CreateMessage(ctx, threadID, openai.MessageRequest{ + Role: "user", + Content: "How does AI work?", + FileIds: nil, + Metadata: nil, + }) + checks.NoError(t, err, "CreateMessage error") + if msg.ID != messageID { + t.Fatalf("unexpected message id: '%s'", msg.ID) + } + + var msgs openai.MessagesList + msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil, nil) + checks.NoError(t, err, "ListMessages error") + if len(msgs.Messages) != 1 { + t.Fatalf("unexpected length of fetched messages") + } + + // with pagination options set + limit := 1 + order := "desc" + after := "obj_foo" + before := "obj_bar" + runID := "run_abc123" + msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before, &runID) + checks.NoError(t, err, "ListMessages error") + if len(msgs.Messages) != 1 { + t.Fatalf("unexpected length of fetched messages") + } + + msg, err = client.RetrieveMessage(ctx, threadID, messageID) + checks.NoError(t, err, "RetrieveMessage error") + if msg.ID != messageID { + t.Fatalf("unexpected message id: '%s'", msg.ID) + } + + msg, err = client.ModifyMessage(ctx, threadID, messageID, + map[string]string{ + "foo": "bar", + }) + checks.NoError(t, err, "ModifyMessage error") + if msg.Metadata["foo"] != "bar" { + t.Fatalf("expected message metadata to get modified") + } + + msgDel, err := client.DeleteMessage(ctx, threadID, messageID) + checks.NoError(t, err, "DeleteMessage error") + if msgDel.ID != messageID { + t.Fatalf("unexpected message id: '%s'", msg.ID) + } + if !msgDel.Deleted { + t.Fatalf("expected deleted is true") + } + _, err = client.DeleteMessage(ctx, threadID, "not_exist_id") + checks.HasError(t, err, "DeleteMessage error") + + // message files + var msgFile openai.MessageFile + msgFile, err = client.RetrieveMessageFile(ctx, threadID, messageID, fileID) + checks.NoError(t, err, "RetrieveMessageFile error") + if msgFile.ID != fileID { + t.Fatalf("unexpected message file id: '%s'", msgFile.ID) + } + + var msgFiles openai.MessageFilesList + msgFiles, err = client.ListMessageFiles(ctx, threadID, messageID) + checks.NoError(t, err, "RetrieveMessageFile error") + if len(msgFiles.MessageFiles) != 1 { + t.Fatalf("unexpected count of message files: %d", len(msgFiles.MessageFiles)) + } + if msgFiles.MessageFiles[0].ID != fileID { + t.Fatalf("unexpected message file id: '%s' in list message files", msgFiles.MessageFiles[0].ID) + } +} diff --git a/models.go b/models.go index 2be91aadb..d94f98836 100644 --- a/models.go +++ b/models.go @@ -2,6 +2,7 @@ package openai import ( "context" + "fmt" "net/http" ) @@ -14,6 +15,8 @@ type Model struct { Permission []Permission `json:"permission"` Root string `json:"root"` Parent string `json:"parent"` + + httpHeader } // Permission struct represents an OpenAPI permission. @@ -32,15 +35,26 @@ type Permission struct { IsBlocking bool `json:"is_blocking"` } +// FineTuneModelDeleteResponse represents the deletion status of a fine-tuned model. +type FineTuneModelDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + // ModelsList is a list of models, including those that belong to the user or organization. type ModelsList struct { Models []Model `json:"data"` + + httpHeader } // ListModels Lists the currently available models, // and provides basic information about each model such as the model id and parent. func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/models"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/models")) if err != nil { return } @@ -48,3 +62,29 @@ func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) err = c.sendRequest(req, &models) return } + +// GetModel Retrieves a model instance, providing basic information about +// the model such as the owner and permissioning. +func (c *Client) GetModel(ctx context.Context, modelID string) (model Model, err error) { + urlSuffix := fmt.Sprintf("/models/%s", modelID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &model) + return +} + +// DeleteFineTuneModel Deletes a fine-tune model. You must have the Owner +// role in your organization to delete a model. +func (c *Client) DeleteFineTuneModel(ctx context.Context, modelID string) ( + response FineTuneModelDeleteResponse, err error) { + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/models/"+modelID)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/models_test.go b/models_test.go index 70d6d756c..7fd010c34 100644 --- a/models_test.go +++ b/models_test.go @@ -1,56 +1,112 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" "net/http" + "os" "testing" + "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) -// TestListModels Tests the models endpoint of the API using the mocked server. -func TestListModels(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/models", handleModelsEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() +const testFineTuneModelID = "fine-tune-model-id" - _, err = client.ListModels(ctx) +// TestListModels Tests the list models endpoint of the API using the mocked server. +func TestListModels(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models", handleListModelsEndpoint) + _, err := client.ListModels(context.Background()) checks.NoError(t, err, "ListModels error") } func TestAzureListModels(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/openai/models", handleModelsEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/", "dummyengine") - config.BaseURL = ts.URL - client := NewClientWithConfig(config) + client, server, teardown := setupAzureTestServer() + defer teardown() + server.RegisterHandler("/openai/models", handleListModelsEndpoint) + _, err := client.ListModels(context.Background()) + checks.NoError(t, err, "ListModels error") +} + +// handleListModelsEndpoint Handles the list models endpoint by the test server. +func handleListModelsEndpoint(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.ModelsList{}) + fmt.Fprintln(w, string(resBytes)) +} + +// TestGetModel Tests the retrieve model endpoint of the API using the mocked server. +func TestGetModel(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/text-davinci-003", handleGetModelEndpoint) + _, err := client.GetModel(context.Background(), "text-davinci-003") + checks.NoError(t, err, "GetModel error") +} + +// TestGetModelO3 Tests the retrieve O3 model endpoint of the API using the mocked server. +func TestGetModelO3(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/o3", handleGetModelEndpoint) + _, err := client.GetModel(context.Background(), "o3") + checks.NoError(t, err, "GetModel error for O3") +} + +// TestGetModelO4Mini Tests the retrieve O4Mini model endpoint of the API using the mocked server. +func TestGetModelO4Mini(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/o4-mini", handleGetModelEndpoint) + _, err := client.GetModel(context.Background(), "o4-mini") + checks.NoError(t, err, "GetModel error for O4Mini") +} + +func TestAzureGetModel(t *testing.T) { + client, server, teardown := setupAzureTestServer() + defer teardown() + server.RegisterHandler("/openai/models/text-davinci-003", handleGetModelEndpoint) + _, err := client.GetModel(context.Background(), "text-davinci-003") + checks.NoError(t, err, "GetModel error") +} + +// handleGetModelsEndpoint Handles the get model endpoint by the test server. +func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.Model{}) + fmt.Fprintln(w, string(resBytes)) +} + +func TestGetModelReturnTimeoutError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/text-davinci-003", func(http.ResponseWriter, *http.Request) { + time.Sleep(10 * time.Nanosecond) + }) ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) + defer cancel() - _, err = client.ListModels(ctx) - checks.NoError(t, err, "ListModels error") + _, err := client.GetModel(ctx, "text-davinci-003") + if err == nil { + t.Fatal("Did not return error") + } + if !os.IsTimeout(err) { + t.Fatal("Did not return timeout error") + } +} + +func TestDeleteFineTuneModel(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/"+testFineTuneModelID, handleDeleteFineTuneModelEndpoint) + _, err := client.DeleteFineTuneModel(context.Background(), testFineTuneModelID) + checks.NoError(t, err, "DeleteFineTuneModel error") } -// handleModelsEndpoint Handles the models endpoint by the test server. -func handleModelsEndpoint(w http.ResponseWriter, _ *http.Request) { - resBytes, _ := json.Marshal(ModelsList{}) +func handleDeleteFineTuneModelEndpoint(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuneModelDeleteResponse{}) fmt.Fprintln(w, string(resBytes)) } diff --git a/moderation.go b/moderation.go index b386ddb95..a0e09c0ee 100644 --- a/moderation.go +++ b/moderation.go @@ -2,6 +2,7 @@ package openai import ( "context" + "errors" "net/http" ) @@ -13,11 +14,25 @@ import ( // If you use text-moderation-stable, we will provide advanced notice before updating the model. // Accuracy of text-moderation-stable may be slightly lower than for text-moderation-latest. const ( - ModerationTextStable = "text-moderation-stable" - ModerationTextLatest = "text-moderation-latest" - ModerationText001 = "text-moderation-001" + ModerationOmniLatest = "omni-moderation-latest" + ModerationOmni20240926 = "omni-moderation-2024-09-26" + ModerationTextStable = "text-moderation-stable" + ModerationTextLatest = "text-moderation-latest" + // Deprecated: use ModerationTextStable and ModerationTextLatest instead. + ModerationText001 = "text-moderation-001" ) +var ( + ErrModerationInvalidModel = errors.New("this model is not supported with moderation, please use text-moderation-stable or text-moderation-latest instead") //nolint:lll +) + +var validModerationModel = map[string]struct{}{ + ModerationOmniLatest: {}, + ModerationOmni20240926: {}, + ModerationTextStable: {}, + ModerationTextLatest: {}, +} + // ModerationRequest represents a request structure for moderation API. type ModerationRequest struct { Input string `json:"input,omitempty"` @@ -33,24 +48,32 @@ type Result struct { // ResultCategories represents Categories of Result. type ResultCategories struct { - Hate bool `json:"hate"` - HateThreatening bool `json:"hate/threatening"` - SelfHarm bool `json:"self-harm"` - Sexual bool `json:"sexual"` - SexualMinors bool `json:"sexual/minors"` - Violence bool `json:"violence"` - ViolenceGraphic bool `json:"violence/graphic"` + Hate bool `json:"hate"` + HateThreatening bool `json:"hate/threatening"` + Harassment bool `json:"harassment"` + HarassmentThreatening bool `json:"harassment/threatening"` + SelfHarm bool `json:"self-harm"` + SelfHarmIntent bool `json:"self-harm/intent"` + SelfHarmInstructions bool `json:"self-harm/instructions"` + Sexual bool `json:"sexual"` + SexualMinors bool `json:"sexual/minors"` + Violence bool `json:"violence"` + ViolenceGraphic bool `json:"violence/graphic"` } // ResultCategoryScores represents CategoryScores of Result. type ResultCategoryScores struct { - Hate float32 `json:"hate"` - HateThreatening float32 `json:"hate/threatening"` - SelfHarm float32 `json:"self-harm"` - Sexual float32 `json:"sexual"` - SexualMinors float32 `json:"sexual/minors"` - Violence float32 `json:"violence"` - ViolenceGraphic float32 `json:"violence/graphic"` + Hate float32 `json:"hate"` + HateThreatening float32 `json:"hate/threatening"` + Harassment float32 `json:"harassment"` + HarassmentThreatening float32 `json:"harassment/threatening"` + SelfHarm float32 `json:"self-harm"` + SelfHarmIntent float32 `json:"self-harm/intent"` + SelfHarmInstructions float32 `json:"self-harm/instructions"` + Sexual float32 `json:"sexual"` + SexualMinors float32 `json:"sexual/minors"` + Violence float32 `json:"violence"` + ViolenceGraphic float32 `json:"violence/graphic"` } // ModerationResponse represents a response structure for moderation API. @@ -58,12 +81,23 @@ type ModerationResponse struct { ID string `json:"id"` Model string `json:"model"` Results []Result `json:"results"` + + httpHeader } // Moderations — perform a moderation api call over a string. // Input can be an array or slice but a string will reduce the complexity. func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) { - req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations"), request) + if _, ok := validModerationModel[request.Model]; len(request.Model) > 0 && !ok { + err = ErrModerationInvalidModel + return + } + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/moderations", withModel(request.Model)), + withBody(&request), + ) if err != nil { return } diff --git a/moderation_test.go b/moderation_test.go index 2c1145627..a97f25bc6 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -1,10 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" @@ -14,33 +10,60 @@ import ( "strings" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) // TestModeration Tests the moderations endpoint of the API using the mocked server. func TestModerations(t *testing.T) { - server := test.NewTestServer() + client, server, teardown := setupOpenAITestServer() + defer teardown() server.RegisterHandler("/v1/moderations", handleModerationEndpoint) - // create the test server - var err error - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - ctx := context.Background() - - // create an edit request - model := "text-moderation-stable" - moderationReq := ModerationRequest{ - Model: model, + _, err := client.Moderations(context.Background(), openai.ModerationRequest{ + Model: openai.ModerationTextStable, Input: "I want to kill them.", - } - _, err = client.Moderations(ctx, moderationReq) + }) checks.NoError(t, err, "Moderation error") } +// TestModerationsWithIncorrectModel Tests passing valid and invalid models to moderations endpoint. +func TestModerationsWithDifferentModelOptions(t *testing.T) { + var modelOptions []struct { + model string + expect error + } + modelOptions = append(modelOptions, + getModerationModelTestOption(openai.GPT3Dot5Turbo, openai.ErrModerationInvalidModel), + getModerationModelTestOption(openai.ModerationTextStable, nil), + getModerationModelTestOption(openai.ModerationTextLatest, nil), + getModerationModelTestOption(openai.ModerationOmni20240926, nil), + getModerationModelTestOption(openai.ModerationOmniLatest, nil), + getModerationModelTestOption("", nil), + ) + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/moderations", handleModerationEndpoint) + for _, modelTest := range modelOptions { + _, err := client.Moderations(context.Background(), openai.ModerationRequest{ + Model: modelTest.model, + Input: "I want to kill them.", + }) + checks.ErrorIs(t, err, modelTest.expect, + fmt.Sprintf("Moderations(..) expects err: %v, actual err:%v", modelTest.expect, err)) + } +} + +func getModerationModelTestOption(model string, expect error) struct { + model string + expect error +} { + return struct { + model string + expect error + }{model: model, expect: expect} +} + // handleModerationEndpoint Handles the moderation endpoint by the test server. func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { var err error @@ -50,32 +73,63 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var moderationReq ModerationRequest + var moderationReq openai.ModerationRequest if moderationReq, err = getModerationBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - resCat := ResultCategories{} - resCatScore := ResultCategoryScores{} + resCat := openai.ResultCategories{} + resCatScore := openai.ResultCategoryScores{} switch { - case strings.Contains(moderationReq.Input, "kill"): - resCat = ResultCategories{Violence: true} - resCatScore = ResultCategoryScores{Violence: 1} case strings.Contains(moderationReq.Input, "hate"): - resCat = ResultCategories{Hate: true} - resCatScore = ResultCategoryScores{Hate: 1} + resCat = openai.ResultCategories{Hate: true} + resCatScore = openai.ResultCategoryScores{Hate: 1} + + case strings.Contains(moderationReq.Input, "hate more"): + resCat = openai.ResultCategories{HateThreatening: true} + resCatScore = openai.ResultCategoryScores{HateThreatening: 1} + + case strings.Contains(moderationReq.Input, "harass"): + resCat = openai.ResultCategories{Harassment: true} + resCatScore = openai.ResultCategoryScores{Harassment: 1} + + case strings.Contains(moderationReq.Input, "harass hard"): + resCat = openai.ResultCategories{Harassment: true} + resCatScore = openai.ResultCategoryScores{HarassmentThreatening: 1} + case strings.Contains(moderationReq.Input, "suicide"): - resCat = ResultCategories{SelfHarm: true} - resCatScore = ResultCategoryScores{SelfHarm: 1} + resCat = openai.ResultCategories{SelfHarm: true} + resCatScore = openai.ResultCategoryScores{SelfHarm: 1} + + case strings.Contains(moderationReq.Input, "wanna suicide"): + resCat = openai.ResultCategories{SelfHarmIntent: true} + resCatScore = openai.ResultCategoryScores{SelfHarm: 1} + + case strings.Contains(moderationReq.Input, "drink bleach"): + resCat = openai.ResultCategories{SelfHarmInstructions: true} + resCatScore = openai.ResultCategoryScores{SelfHarmInstructions: 1} + case strings.Contains(moderationReq.Input, "porn"): - resCat = ResultCategories{Sexual: true} - resCatScore = ResultCategoryScores{Sexual: 1} + resCat = openai.ResultCategories{Sexual: true} + resCatScore = openai.ResultCategoryScores{Sexual: 1} + + case strings.Contains(moderationReq.Input, "child porn"): + resCat = openai.ResultCategories{SexualMinors: true} + resCatScore = openai.ResultCategoryScores{SexualMinors: 1} + + case strings.Contains(moderationReq.Input, "kill"): + resCat = openai.ResultCategories{Violence: true} + resCatScore = openai.ResultCategoryScores{Violence: 1} + + case strings.Contains(moderationReq.Input, "corpse"): + resCat = openai.ResultCategories{ViolenceGraphic: true} + resCatScore = openai.ResultCategoryScores{ViolenceGraphic: 1} } - result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} + result := openai.Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} - res := ModerationResponse{ + res := openai.ModerationResponse{ ID: strconv.Itoa(int(time.Now().Unix())), Model: moderationReq.Model, } @@ -86,16 +140,16 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { } // getModerationBody Returns the body of the request to do a moderation. -func getModerationBody(r *http.Request) (ModerationRequest, error) { - moderation := ModerationRequest{} +func getModerationBody(r *http.Request) (openai.ModerationRequest, error) { + moderation := openai.ModerationRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return ModerationRequest{}, err + return openai.ModerationRequest{}, err } err = json.Unmarshal(reqBody, &moderation) if err != nil { - return ModerationRequest{}, err + return openai.ModerationRequest{}, err } return moderation, nil } diff --git a/openai_test.go b/openai_test.go new file mode 100644 index 000000000..a55f3a858 --- /dev/null +++ b/openai_test.go @@ -0,0 +1,37 @@ +package openai_test + +import ( + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" +) + +func setupOpenAITestServer() (client *openai.Client, server *test.ServerTest, teardown func()) { + server = test.NewTestServer() + ts := server.OpenAITestServer() + ts.Start() + teardown = ts.Close + config := openai.DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client = openai.NewClientWithConfig(config) + return +} + +func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, teardown func()) { + server = test.NewTestServer() + ts := server.OpenAITestServer() + ts.Start() + teardown = ts.Close + config := openai.DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") + config.BaseURL = ts.URL + client = openai.NewClientWithConfig(config) + return +} + +// numTokens Returns the number of GPT-3 encoded tokens in the given text. +// This function approximates based on the rule of thumb stated by OpenAI: +// https://beta.openai.com/tokenizer. +// +// TODO: implement an actual tokenizer for GPT-3 and Codex (once available). +func numTokens(s string) int { + return int(float32(len(s)) / 4) +} diff --git a/ratelimit.go b/ratelimit.go new file mode 100644 index 000000000..e8953f716 --- /dev/null +++ b/ratelimit.go @@ -0,0 +1,43 @@ +package openai + +import ( + "net/http" + "strconv" + "time" +) + +// RateLimitHeaders struct represents Openai rate limits headers. +type RateLimitHeaders struct { + LimitRequests int `json:"x-ratelimit-limit-requests"` + LimitTokens int `json:"x-ratelimit-limit-tokens"` + RemainingRequests int `json:"x-ratelimit-remaining-requests"` + RemainingTokens int `json:"x-ratelimit-remaining-tokens"` + ResetRequests ResetTime `json:"x-ratelimit-reset-requests"` + ResetTokens ResetTime `json:"x-ratelimit-reset-tokens"` +} + +type ResetTime string + +func (r ResetTime) String() string { + return string(r) +} + +func (r ResetTime) Time() time.Time { + d, _ := time.ParseDuration(string(r)) + return time.Now().Add(d) +} + +func newRateLimitHeaders(h http.Header) RateLimitHeaders { + limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests")) + limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens")) + remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests")) + remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens")) + return RateLimitHeaders{ + LimitRequests: limitReq, + LimitTokens: limitTokens, + RemainingRequests: remainingReq, + RemainingTokens: remainingTokens, + ResetRequests: ResetTime(h.Get("x-ratelimit-reset-requests")), + ResetTokens: ResetTime(h.Get("x-ratelimit-reset-tokens")), + } +} diff --git a/reasoning_validator.go b/reasoning_validator.go new file mode 100644 index 000000000..1d26ca047 --- /dev/null +++ b/reasoning_validator.go @@ -0,0 +1,82 @@ +package openai + +import ( + "errors" + "strings" +) + +var ( + // Deprecated: use ErrReasoningModelMaxTokensDeprecated instead. + ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll + ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll + ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll + ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll +) + +var ( + ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll + ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll + // Deprecated: use ErrReasoningModelLimitations* instead. + ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll + ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll +) + +var ( + //nolint:lll + ErrReasoningModelMaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") + ErrReasoningModelLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll + ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll +) + +// ReasoningValidator handles validation for reasoning model requests. +type ReasoningValidator struct{} + +// NewReasoningValidator creates a new validator for reasoning models. +func NewReasoningValidator() *ReasoningValidator { + return &ReasoningValidator{} +} + +// Validate performs all validation checks for reasoning models. +func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error { + o1Series := strings.HasPrefix(request.Model, "o1") + o3Series := strings.HasPrefix(request.Model, "o3") + o4Series := strings.HasPrefix(request.Model, "o4") + gpt5Series := strings.HasPrefix(request.Model, "gpt-5") + + if !o1Series && !o3Series && !o4Series && !gpt5Series { + return nil + } + + if err := v.validateReasoningModelParams(request); err != nil { + return err + } + + return nil +} + +// validateReasoningModelParams checks reasoning model parameters. +func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletionRequest) error { + if request.MaxTokens > 0 { + return ErrReasoningModelMaxTokensDeprecated + } + if request.LogProbs { + return ErrReasoningModelLimitationsLogprobs + } + if request.Temperature > 0 && request.Temperature != 1 { + return ErrReasoningModelLimitationsOther + } + if request.TopP > 0 && request.TopP != 1 { + return ErrReasoningModelLimitationsOther + } + if request.N > 0 && request.N != 1 { + return ErrReasoningModelLimitationsOther + } + if request.PresencePenalty > 0 { + return ErrReasoningModelLimitationsOther + } + if request.FrequencyPenalty > 0 { + return ErrReasoningModelLimitationsOther + } + + return nil +} diff --git a/request_builder.go b/request_builder.go deleted file mode 100644 index f0cef10fe..000000000 --- a/request_builder.go +++ /dev/null @@ -1,40 +0,0 @@ -package openai - -import ( - "bytes" - "context" - "net/http" -) - -type requestBuilder interface { - build(ctx context.Context, method, url string, request any) (*http.Request, error) -} - -type httpRequestBuilder struct { - marshaller marshaller -} - -func newRequestBuilder() *httpRequestBuilder { - return &httpRequestBuilder{ - marshaller: &jsonMarshaller{}, - } -} - -func (b *httpRequestBuilder) build(ctx context.Context, method, url string, request any) (*http.Request, error) { - if request == nil { - return http.NewRequestWithContext(ctx, method, url, nil) - } - - var reqBytes []byte - reqBytes, err := b.marshaller.marshal(request) - if err != nil { - return nil, err - } - - return http.NewRequestWithContext( - ctx, - method, - url, - bytes.NewBuffer(reqBytes), - ) -} diff --git a/request_builder_test.go b/request_builder_test.go deleted file mode 100644 index b1adbf1c6..000000000 --- a/request_builder_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package openai //nolint:testpackage // testing private field - -import ( - "github.com/sashabaranov/go-openai/internal/test" - - "context" - "errors" - "net/http" - "testing" -) - -var ( - errTestMarshallerFailed = errors.New("test marshaller failed") - errTestRequestBuilderFailed = errors.New("test request builder failed") -) - -type ( - failingRequestBuilder struct{} - failingMarshaller struct{} -) - -func (*failingMarshaller) marshal(_ any) ([]byte, error) { - return []byte{}, errTestMarshallerFailed -} - -func (*failingRequestBuilder) build(_ context.Context, _, _ string, _ any) (*http.Request, error) { - return nil, errTestRequestBuilderFailed -} - -func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) { - builder := httpRequestBuilder{ - marshaller: &failingMarshaller{}, - } - - _, err := builder.build(context.Background(), "", "", struct{}{}) - if !errors.Is(err, errTestMarshallerFailed) { - t.Fatalf("Did not return error when marshaller failed: %v", err) - } -} - -func TestClientReturnsRequestBuilderErrors(t *testing.T) { - var err error - ts := test.NewTestServer().OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - client.requestBuilder = &failingRequestBuilder{} - - ctx := context.Background() - - _, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateFineTune(ctx, FineTuneRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListFineTunes(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CancelFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.DeleteFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListFineTuneEvents(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.Moderations(ctx, ModerationRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.Edits(ctx, EditsRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateImage(ctx, ImageRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - err = client.DeleteFile(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetFile(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListFiles(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListEngines(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetEngine(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListModels(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } -} - -func TestReturnsRequestBuilderErrorsAddtion(t *testing.T) { - var err error - ts := test.NewTestServer().OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - client.requestBuilder = &failingRequestBuilder{} - - ctx := context.Background() - - _, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1}) - if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1}) - if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } -} diff --git a/run.go b/run.go new file mode 100644 index 000000000..9c51aaf8d --- /dev/null +++ b/run.go @@ -0,0 +1,454 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +type Run struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + ThreadID string `json:"thread_id"` + AssistantID string `json:"assistant_id"` + Status RunStatus `json:"status"` + RequiredAction *RunRequiredAction `json:"required_action,omitempty"` + LastError *RunLastError `json:"last_error,omitempty"` + ExpiresAt int64 `json:"expires_at"` + StartedAt *int64 `json:"started_at,omitempty"` + CancelledAt *int64 `json:"cancelled_at,omitempty"` + FailedAt *int64 `json:"failed_at,omitempty"` + CompletedAt *int64 `json:"completed_at,omitempty"` + Model string `json:"model"` + Instructions string `json:"instructions,omitempty"` + Tools []Tool `json:"tools"` + FileIDS []string `json:"file_ids"` //nolint:revive // backwards-compatibility + Metadata map[string]any `json:"metadata"` + Usage Usage `json:"usage,omitempty"` + + Temperature *float32 `json:"temperature,omitempty"` + // The maximum number of prompt tokens that may be used over the course of the run. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'incomplete'. + MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` + // The maximum number of completion tokens that may be used over the course of the run. + // If the run exceeds the number of completion tokens specified, the run will end with status 'incomplete'. + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + // ThreadTruncationStrategy defines the truncation strategy to use for the thread. + TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` + + httpHeader +} + +type RunStatus string + +const ( + RunStatusQueued RunStatus = "queued" + RunStatusInProgress RunStatus = "in_progress" + RunStatusRequiresAction RunStatus = "requires_action" + RunStatusCancelling RunStatus = "cancelling" + RunStatusFailed RunStatus = "failed" + RunStatusCompleted RunStatus = "completed" + RunStatusIncomplete RunStatus = "incomplete" + RunStatusExpired RunStatus = "expired" + RunStatusCancelled RunStatus = "cancelled" +) + +type RunRequiredAction struct { + Type RequiredActionType `json:"type"` + SubmitToolOutputs *SubmitToolOutputs `json:"submit_tool_outputs,omitempty"` +} + +type RequiredActionType string + +const ( + RequiredActionTypeSubmitToolOutputs RequiredActionType = "submit_tool_outputs" +) + +type SubmitToolOutputs struct { + ToolCalls []ToolCall `json:"tool_calls"` +} + +type RunLastError struct { + Code RunError `json:"code"` + Message string `json:"message"` +} + +type RunError string + +const ( + RunErrorServerError RunError = "server_error" + RunErrorRateLimitExceeded RunError = "rate_limit_exceeded" +) + +type RunRequest struct { + AssistantID string `json:"assistant_id"` + Model string `json:"model,omitempty"` + Instructions string `json:"instructions,omitempty"` + AdditionalInstructions string `json:"additional_instructions,omitempty"` + AdditionalMessages []ThreadMessage `json:"additional_messages,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + + // Sampling temperature between 0 and 2. Higher values like 0.8 are more random. + // lower values are more focused and deterministic. + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` + + // The maximum number of prompt tokens that may be used over the course of the run. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'incomplete'. + MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` + + // The maximum number of completion tokens that may be used over the course of the run. + // If the run exceeds the number of completion tokens specified, the run will end with status 'incomplete'. + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + + // ThreadTruncationStrategy defines the truncation strategy to use for the thread. + TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` + + // This can be either a string or a ToolChoice object. + ToolChoice any `json:"tool_choice,omitempty"` + // This can be either a string or a ResponseFormat object. + ResponseFormat any `json:"response_format,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. + ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` +} + +// ThreadTruncationStrategy defines the truncation strategy to use for the thread. +// https://platform.openai.com/docs/assistants/how-it-works/truncation-strategy. +type ThreadTruncationStrategy struct { + // default 'auto'. + Type TruncationStrategy `json:"type,omitempty"` + // this field should be set if the truncation strategy is set to LastMessages. + LastMessages *int `json:"last_messages,omitempty"` +} + +// TruncationStrategy defines the existing truncation strategies existing for thread management in an assistant. +type TruncationStrategy string + +const ( + // TruncationStrategyAuto messages in the middle of the thread will be dropped to fit the context length of the model. + TruncationStrategyAuto = TruncationStrategy("auto") + // TruncationStrategyLastMessages the thread will be truncated to the n most recent messages in the thread. + TruncationStrategyLastMessages = TruncationStrategy("last_messages") +) + +// ReponseFormat specifies the format the model must output. +// https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-response_format. +// Type can either be text or json_object. +type ReponseFormat struct { + Type string `json:"type"` +} + +type RunModifyRequest struct { + Metadata map[string]any `json:"metadata,omitempty"` +} + +// RunList is a list of runs. +type RunList struct { + Runs []Run `json:"data"` + + httpHeader +} + +type SubmitToolOutputsRequest struct { + ToolOutputs []ToolOutput `json:"tool_outputs"` +} + +type ToolOutput struct { + ToolCallID string `json:"tool_call_id"` + Output any `json:"output"` +} + +type CreateThreadAndRunRequest struct { + RunRequest + Thread ThreadRequest `json:"thread"` +} + +type RunStep struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + AssistantID string `json:"assistant_id"` + ThreadID string `json:"thread_id"` + RunID string `json:"run_id"` + Type RunStepType `json:"type"` + Status RunStepStatus `json:"status"` + StepDetails StepDetails `json:"step_details"` + LastError *RunLastError `json:"last_error,omitempty"` + ExpiredAt *int64 `json:"expired_at,omitempty"` + CancelledAt *int64 `json:"cancelled_at,omitempty"` + FailedAt *int64 `json:"failed_at,omitempty"` + CompletedAt *int64 `json:"completed_at,omitempty"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type RunStepStatus string + +const ( + RunStepStatusInProgress RunStepStatus = "in_progress" + RunStepStatusCancelling RunStepStatus = "cancelled" + RunStepStatusFailed RunStepStatus = "failed" + RunStepStatusCompleted RunStepStatus = "completed" + RunStepStatusExpired RunStepStatus = "expired" +) + +type RunStepType string + +const ( + RunStepTypeMessageCreation RunStepType = "message_creation" + RunStepTypeToolCalls RunStepType = "tool_calls" +) + +type StepDetails struct { + Type RunStepType `json:"type"` + MessageCreation *StepDetailsMessageCreation `json:"message_creation,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` +} + +type StepDetailsMessageCreation struct { + MessageID string `json:"message_id"` +} + +// RunStepList is a list of steps. +type RunStepList struct { + RunSteps []RunStep `json:"data"` + + FirstID string `json:"first_id"` + LastID string `json:"last_id"` + HasMore bool `json:"has_more"` + + httpHeader +} + +type Pagination struct { + Limit *int + Order *string + After *string + Before *string +} + +// CreateRun creates a new run. +func (c *Client) CreateRun( + ctx context.Context, + threadID string, + request RunRequest, +) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs", threadID) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveRun retrieves a run. +func (c *Client) RetrieveRun( + ctx context.Context, + threadID string, + runID string, +) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s", threadID, runID) + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ModifyRun modifies a run. +func (c *Client) ModifyRun( + ctx context.Context, + threadID string, + runID string, + request RunModifyRequest, +) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s", threadID, runID) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ListRuns lists runs. +func (c *Client) ListRuns( + ctx context.Context, + threadID string, + pagination Pagination, +) (response RunList, err error) { + urlValues := url.Values{} + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("/threads/%s/runs%s", threadID, encodedValues) + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// SubmitToolOutputs submits tool outputs. +func (c *Client) SubmitToolOutputs( + ctx context.Context, + threadID string, + runID string, + request SubmitToolOutputsRequest) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/submit_tool_outputs", threadID, runID) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// CancelRun cancels a run. +func (c *Client) CancelRun( + ctx context.Context, + threadID string, + runID string) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/cancel", threadID, runID) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// CreateThreadAndRun submits tool outputs. +func (c *Client) CreateThreadAndRun( + ctx context.Context, + request CreateThreadAndRunRequest) (response Run, err error) { + urlSuffix := "/threads/runs" + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveRunStep retrieves a run step. +func (c *Client) RetrieveRunStep( + ctx context.Context, + threadID string, + runID string, + stepID string, +) (response RunStep, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/steps/%s", threadID, runID, stepID) + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ListRunSteps lists run steps. +func (c *Client) ListRunSteps( + ctx context.Context, + threadID string, + runID string, + pagination Pagination, +) (response RunStepList, err error) { + urlValues := url.Values{} + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/steps%s", threadID, runID, encodedValues) + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/run_test.go b/run_test.go new file mode 100644 index 000000000..cdf99db05 --- /dev/null +++ b/run_test.go @@ -0,0 +1,237 @@ +package openai_test + +import ( + "context" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +// TestAssistant Tests the assistant endpoint of the API using the mocked server. +func TestRun(t *testing.T) { + assistantID := "asst_abc123" + threadID := "thread_abc123" + runID := "run_abc123" + stepID := "step_abc123" + limit := 20 + order := "desc" + after := "asst_abc122" + before := "asst_abc124" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID+"/steps/"+stepID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.RunStep{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStepStatusCompleted, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID+"/steps", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.RunStepList{ + RunSteps: []openai.RunStep{ + { + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStepStatusCompleted, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID+"/cancel", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusCancelling, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID+"/submit_tool_outputs", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusCancelling, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.RunModifyRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + Metadata: request.Metadata, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.RunRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.RunList{ + Runs: []openai.Run{ + { + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/runs", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.CreateThreadAndRunRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateRun(ctx, threadID, openai.RunRequest{ + AssistantID: assistantID, + }) + checks.NoError(t, err, "CreateRun error") + + _, err = client.RetrieveRun(ctx, threadID, runID) + checks.NoError(t, err, "RetrieveRun error") + + _, err = client.ModifyRun(ctx, threadID, runID, openai.RunModifyRequest{ + Metadata: map[string]any{ + "key": "value", + }, + }) + checks.NoError(t, err, "ModifyRun error") + + _, err = client.ListRuns( + ctx, + threadID, + openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }, + ) + checks.NoError(t, err, "ListRuns error") + + _, err = client.SubmitToolOutputs(ctx, threadID, runID, + openai.SubmitToolOutputsRequest{}) + checks.NoError(t, err, "SubmitToolOutputs error") + + _, err = client.CancelRun(ctx, threadID, runID) + checks.NoError(t, err, "CancelRun error") + + _, err = client.CreateThreadAndRun(ctx, openai.CreateThreadAndRunRequest{ + RunRequest: openai.RunRequest{ + AssistantID: assistantID, + }, + Thread: openai.ThreadRequest{ + Messages: []openai.ThreadMessage{ + { + Role: openai.ThreadMessageRoleUser, + Content: "Hello, World!", + }, + }, + }, + }) + checks.NoError(t, err, "CreateThreadAndRun error") + + _, err = client.RetrieveRunStep(ctx, threadID, runID, stepID) + checks.NoError(t, err, "RetrieveRunStep error") + + _, err = client.ListRunSteps( + ctx, + threadID, + runID, + openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }, + ) + checks.NoError(t, err, "ListRunSteps error") +} diff --git a/speech.go b/speech.go new file mode 100644 index 000000000..60e7694fd --- /dev/null +++ b/speech.go @@ -0,0 +1,65 @@ +package openai + +import ( + "context" + "net/http" +) + +type SpeechModel string + +const ( + TTSModel1 SpeechModel = "tts-1" + TTSModel1HD SpeechModel = "tts-1-hd" + TTSModelCanary SpeechModel = "canary-tts" + TTSModelGPT4oMini SpeechModel = "gpt-4o-mini-tts" +) + +type SpeechVoice string + +const ( + VoiceAlloy SpeechVoice = "alloy" + VoiceAsh SpeechVoice = "ash" + VoiceBallad SpeechVoice = "ballad" + VoiceCoral SpeechVoice = "coral" + VoiceEcho SpeechVoice = "echo" + VoiceFable SpeechVoice = "fable" + VoiceOnyx SpeechVoice = "onyx" + VoiceNova SpeechVoice = "nova" + VoiceShimmer SpeechVoice = "shimmer" + VoiceVerse SpeechVoice = "verse" +) + +type SpeechResponseFormat string + +const ( + SpeechResponseFormatMp3 SpeechResponseFormat = "mp3" + SpeechResponseFormatOpus SpeechResponseFormat = "opus" + SpeechResponseFormatAac SpeechResponseFormat = "aac" + SpeechResponseFormatFlac SpeechResponseFormat = "flac" + SpeechResponseFormatWav SpeechResponseFormat = "wav" + SpeechResponseFormatPcm SpeechResponseFormat = "pcm" +) + +type CreateSpeechRequest struct { + Model SpeechModel `json:"model"` + Input string `json:"input"` + Voice SpeechVoice `json:"voice"` + Instructions string `json:"instructions,omitempty"` // Optional, Doesnt work with tts-1 or tts-1-hd. + ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3 + Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 +} + +func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/audio/speech", withModel(string(request.Model))), + withBody(request), + withContentType("application/json"), + ) + if err != nil { + return + } + + return c.sendRequestRaw(req) +} diff --git a/speech_test.go b/speech_test.go new file mode 100644 index 000000000..67a3feabc --- /dev/null +++ b/speech_test.go @@ -0,0 +1,96 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "mime" + "net/http" + "os" + "path/filepath" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestSpeechIntegration(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) { + path := filepath.Join(t.TempDir(), "fake.mp3") + test.CreateTestFile(t, path) + + // audio endpoints only accept POST requests + if r.Method != "POST" { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil { + http.Error(w, "failed to parse media type", http.StatusBadRequest) + return + } + + if mediaType != "application/json" { + http.Error(w, "request is not json", http.StatusBadRequest) + return + } + + // Parse the JSON body of the request + var params map[string]interface{} + err = json.NewDecoder(r.Body).Decode(¶ms) + if err != nil { + http.Error(w, "failed to parse request body", http.StatusBadRequest) + return + } + + // Check if each required field is present in the parsed JSON object + reqParams := []string{"model", "input", "voice"} + for _, param := range reqParams { + _, ok := params[param] + if !ok { + http.Error(w, fmt.Sprintf("no %s in params", param), http.StatusBadRequest) + return + } + } + + // read audio file content + audioFile, err := os.ReadFile(path) + if err != nil { + http.Error(w, "failed to read audio file", http.StatusInternalServerError) + return + } + + // write audio file content to response + w.Header().Set("Content-Type", "audio/mpeg") + w.Header().Set("Transfer-Encoding", "chunked") + w.Header().Set("Connection", "keep-alive") + _, err = w.Write(audioFile) + if err != nil { + http.Error(w, "failed to write body", http.StatusInternalServerError) + return + } + }) + + t.Run("happy path", func(t *testing.T) { + res, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ + Model: openai.TTSModel1, + Input: "Hello!", + Voice: openai.VoiceAlloy, + }) + checks.NoError(t, err, "CreateSpeech error") + defer res.Close() + + buf, err := io.ReadAll(res) + checks.NoError(t, err, "ReadAll error") + + // save buf to file as mp3 + err = os.WriteFile("test.mp3", buf, 0644) + checks.NoError(t, err, "Create error") + }) +} diff --git a/stream.go b/stream.go index 95662db6d..a61c7c970 100644 --- a/stream.go +++ b/stream.go @@ -1,7 +1,6 @@ package openai import ( - "bufio" "context" "errors" "net/http" @@ -35,27 +34,22 @@ func (c *Client) CreateCompletionStream( } request.Stream = true - req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { - return + return nil, err } - resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + resp, err := sendRequestStream[CompletionResponse](c, req) if err != nil { return } - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest { - return nil, c.handleErrorResp(resp) - } - stream = &CompletionStream{ - streamReader: &streamReader[CompletionResponse]{ - emptyMessagesLimit: c.config.EmptyMessagesLimit, - reader: bufio.NewReader(resp.Body), - response: resp, - errAccumulator: newErrorAccumulator(), - unmarshaler: &jsonUnmarshaler{}, - }, + streamReader: resp, } return } diff --git a/stream_reader.go b/stream_reader.go index aa06f00ae..6faefe0a7 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -6,6 +6,14 @@ import ( "fmt" "io" "net/http" + "regexp" + + utils "github.com/sashabaranov/go-openai/internal" +) + +var ( + headerData = regexp.MustCompile(`^data:\s*`) + errorPrefix = regexp.MustCompile(`^data:\s*{"error":`) ) type streamable interface { @@ -18,55 +26,94 @@ type streamReader[T streamable] struct { reader *bufio.Reader response *http.Response - errAccumulator errorAccumulator - unmarshaler unmarshaler + errAccumulator utils.ErrorAccumulator + unmarshaler utils.Unmarshaler + + httpHeader } func (stream *streamReader[T]) Recv() (response T, err error) { - if stream.isFinished { - err = io.EOF + rawLine, err := stream.RecvRaw() + if err != nil { return } - var emptyMessagesCount uint - -waitForData: - line, err := stream.reader.ReadBytes('\n') + err = stream.unmarshaler.Unmarshal(rawLine, &response) if err != nil { - respErr := stream.errAccumulator.unmarshalError() - if respErr != nil { - err = fmt.Errorf("error, %w", respErr.Error) - } return } + return response, nil +} + +func (stream *streamReader[T]) RecvRaw() ([]byte, error) { + if stream.isFinished { + return nil, io.EOF + } - var headerData = []byte("data: ") - line = bytes.TrimSpace(line) - if !bytes.HasPrefix(line, headerData) { - if writeErr := stream.errAccumulator.write(line); writeErr != nil { - err = writeErr - return + return stream.processLines() +} + +//nolint:gocognit +func (stream *streamReader[T]) processLines() ([]byte, error) { + var ( + emptyMessagesCount uint + hasErrorPrefix bool + ) + + for { + rawLine, readErr := stream.reader.ReadBytes('\n') + if readErr != nil || hasErrorPrefix { + respErr := stream.unmarshalError() + if respErr != nil { + return nil, fmt.Errorf("error, %w", respErr.Error) + } + return nil, readErr + } + + noSpaceLine := bytes.TrimSpace(rawLine) + if errorPrefix.Match(noSpaceLine) { + hasErrorPrefix = true + } + if !headerData.Match(noSpaceLine) || hasErrorPrefix { + if hasErrorPrefix { + noSpaceLine = headerData.ReplaceAll(noSpaceLine, nil) + } + writeErr := stream.errAccumulator.Write(noSpaceLine) + if writeErr != nil { + return nil, writeErr + } + emptyMessagesCount++ + if emptyMessagesCount > stream.emptyMessagesLimit { + return nil, ErrTooManyEmptyStreamMessages + } + + continue } - emptyMessagesCount++ - if emptyMessagesCount > stream.emptyMessagesLimit { - err = ErrTooManyEmptyStreamMessages - return + + noPrefixLine := headerData.ReplaceAll(noSpaceLine, nil) + if string(noPrefixLine) == "[DONE]" { + stream.isFinished = true + return nil, io.EOF } - goto waitForData + return noPrefixLine, nil } +} - line = bytes.TrimPrefix(line, headerData) - if string(line) == "[DONE]" { - stream.isFinished = true - err = io.EOF +func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) { + errBytes := stream.errAccumulator.Bytes() + if len(errBytes) == 0 { return } - err = stream.unmarshaler.unmarshal(line, &response) + err := stream.unmarshaler.Unmarshal(errBytes, &errResp) + if err != nil { + errResp = nil + } + return } -func (stream *streamReader[T]) Close() { - stream.response.Body.Close() +func (stream *streamReader[T]) Close() error { + return stream.response.Body.Close() } diff --git a/stream_reader_test.go b/stream_reader_test.go new file mode 100644 index 000000000..449a14b43 --- /dev/null +++ b/stream_reader_test.go @@ -0,0 +1,78 @@ +package openai //nolint:testpackage // testing private field + +import ( + "bufio" + "bytes" + "errors" + "testing" + + utils "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +var errTestUnmarshalerFailed = errors.New("test unmarshaler failed") + +type failingUnMarshaller struct{} + +func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error { + return errTestUnmarshalerFailed +} + +func TestStreamReaderReturnsUnmarshalerErrors(t *testing.T) { + stream := &streamReader[ChatCompletionStreamResponse]{ + errAccumulator: utils.NewErrorAccumulator(), + unmarshaler: &failingUnMarshaller{}, + } + + respErr := stream.unmarshalError() + if respErr != nil { + t.Fatalf("Did not return nil with empty buffer: %v", respErr) + } + + err := stream.errAccumulator.Write([]byte("{")) + if err != nil { + t.Fatalf("%+v", err) + } + + respErr = stream.unmarshalError() + if respErr != nil { + t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr) + } +} + +func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) { + stream := &streamReader[ChatCompletionStreamResponse]{ + emptyMessagesLimit: 3, + reader: bufio.NewReader(bytes.NewReader([]byte("\n\n\n\n"))), + errAccumulator: utils.NewErrorAccumulator(), + unmarshaler: &utils.JSONUnmarshaler{}, + } + _, err := stream.Recv() + checks.ErrorIs(t, err, ErrTooManyEmptyStreamMessages, "Did not return error when recv failed", err.Error()) +} + +func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) { + stream := &streamReader[ChatCompletionStreamResponse]{ + reader: bufio.NewReader(bytes.NewReader([]byte("\n"))), + errAccumulator: &utils.DefaultErrorAccumulator{ + Buffer: &test.FailingErrorBuffer{}, + }, + unmarshaler: &utils.JSONUnmarshaler{}, + } + _, err := stream.Recv() + checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error()) +} + +func TestStreamReaderRecvRaw(t *testing.T) { + stream := &streamReader[ChatCompletionStreamResponse]{ + reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"key\": \"value\"}\n"))), + } + rawLine, err := stream.RecvRaw() + if err != nil { + t.Fatalf("Did not return raw line: %v", err) + } + if !bytes.Equal(rawLine, []byte("{\"key\": \"value\"}")) { + t.Fatalf("Did not return raw line: %v", string(rawLine)) + } +} diff --git a/stream_test.go b/stream_test.go index a5c591fde..9dd95bb5f 100644 --- a/stream_test.go +++ b/stream_test.go @@ -2,36 +2,39 @@ package openai_test import ( "context" + "encoding/json" "errors" "io" "net/http" - "net/http/httptest" + "os" "testing" + "time" - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestCompletionsStreamWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) _, err := client.CreateCompletionStream( context.Background(), - CompletionRequest{ + openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, + Model: openai.GPT3Dot5Turbo, }, ) - if !errors.Is(err, ErrCompletionUnsupportedModel) { + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) } } func TestCreateCompletionStream(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -51,45 +54,31 @@ func TestCreateCompletionStream(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() + }) - request := CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, Stream: true, - } - - stream, err := client.CreateCompletionStream(ctx, request) + }) checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() - expectedResponses := []CompletionResponse{ + expectedResponses := []openai.CompletionResponse{ { ID: "1", Object: "completion", Created: 1598069254, Model: "text-davinci-002", - Choices: []CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}}, + Choices: []openai.CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}}, }, { ID: "2", Object: "completion", Created: 1598069255, Model: "text-davinci-002", - Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}}, + Choices: []openai.CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}}, }, } @@ -115,7 +104,9 @@ func TestCreateCompletionStream(t *testing.T) { } func TestCreateCompletionStreamError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -136,35 +127,21 @@ func TestCreateCompletionStreamError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") - })) - defer server.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = server.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, - } - - client := NewClientWithConfig(config) - ctx := context.Background() + }) - request := CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3TextDavinci003, + Model: openai.GPT3TextDavinci003, Prompt: "Hello!", Stream: true, - } - - stream, err := client.CreateCompletionStream(ctx, request) + }) checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() _, streamErr := stream.Recv() checks.HasError(t, streamErr, "stream.Recv() did not return error") - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError") } @@ -172,8 +149,9 @@ func TestCreateCompletionStreamError(t *testing.T) { } func TestCreateCompletionStreamRateLimitError(t *testing.T) { - server := test.NewTestServer() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(429) @@ -187,61 +165,169 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") }) - ts := server.OpenAITestServer() - ts.Start() - defer ts.Close() - - // Client portion of the test - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - config.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, - } - client := NewClientWithConfig(config) - ctx := context.Background() - - request := CompletionRequest{ + var apiErr *openai.APIError + _, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3Ada, + Model: openai.GPT3Babbage002, Prompt: "Hello!", Stream: true, - } - - var apiErr *APIError - _, err := client.CreateCompletionStream(ctx, request) + }) if !errors.As(err, &apiErr) { t.Errorf("TestCreateCompletionStreamRateLimitError did not return APIError") } t.Logf("%+v\n", apiErr) } -// A "tokenRoundTripper" is a struct that implements the RoundTripper -// interface, specifically to handle the authentication token by adding a token -// to the request header. We need this because the API requires that each -// request include a valid API token in the headers for authentication and -// authorization. -type tokenRoundTripper struct { - token string - fallback http.RoundTripper +func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + // Totally 301 empty messages (300 is the limit) + for i := 0; i < 299; i++ { + dataBytes = append(dataBytes, '\n') + } + + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data = `{"id":"2","object":"completion","created":1598069255,"model":"text-davinci-002","choices":[{"text":"response2","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("event: done\n")...) + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: "text-davinci-002", + MaxTokens: 10, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + _, _ = stream.Recv() + _, streamErr := stream.Recv() + if !errors.Is(streamErr, openai.ErrTooManyEmptyStreamMessages) { + t.Errorf("TestCreateCompletionStreamTooManyEmptyStreamMessagesError did not return ErrTooManyEmptyStreamMessages") + } } -// RoundTrip takes an *http.Request as input and returns an -// *http.Response and an error. -// -// It is expected to use the provided request to create a connection to an HTTP -// server and return the response, or an error if one occurred. The returned -// Response should have its Body closed. If the RoundTrip method returns an -// error, the Client's Get, Head, Post, and PostForm methods return the same -// error. -func (t *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - req.Header.Set("Authorization", "Bearer "+t.token) - return t.fallback.RoundTrip(req) +func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + // Stream is terminated without sending "done" message + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: "text-davinci-002", + MaxTokens: 10, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + _, _ = stream.Recv() + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("TestCreateCompletionStreamUnexpectedTerminatedError did not return io.EOF") + } +} + +func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + // Send broken json + dataBytes = append(dataBytes, []byte("event: message\n")...) + data = `{"id":"2","object":"completion","created":1598069255,"model":` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("event: done\n")...) + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: "text-davinci-002", + MaxTokens: 10, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + _, _ = stream.Recv() + _, streamErr := stream.Recv() + var syntaxError *json.SyntaxError + if !errors.As(streamErr, &syntaxError) { + t.Errorf("TestCreateCompletionStreamBrokenJSONError did not return json.SyntaxError") + } +} + +func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(http.ResponseWriter, *http.Request) { + time.Sleep(10 * time.Nanosecond) + }) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) + defer cancel() + + _, err := client.CreateCompletionStream(ctx, openai.CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: "text-davinci-002", + MaxTokens: 10, + Stream: true, + }) + if err == nil { + t.Fatal("Did not return error") + } + if !os.IsTimeout(err) { + t.Fatal("Did not return timeout error") + } } // Helper funcs. -func compareResponses(r1, r2 CompletionResponse) bool { +func compareResponses(r1, r2 openai.CompletionResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { return false } @@ -256,7 +342,7 @@ func compareResponses(r1, r2 CompletionResponse) bool { return true } -func compareResponseChoices(c1, c2 CompletionChoice) bool { +func compareResponseChoices(c1, c2 openai.CompletionChoice) bool { if c1.Text != c2.Text || c1.FinishReason != c2.FinishReason { return false } diff --git a/thread.go b/thread.go new file mode 100644 index 000000000..bc08e2bcb --- /dev/null +++ b/thread.go @@ -0,0 +1,171 @@ +package openai + +import ( + "context" + "net/http" +) + +const ( + threadsSuffix = "/threads" +) + +type Thread struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Metadata map[string]any `json:"metadata"` + ToolResources ToolResources `json:"tool_resources,omitempty"` + + httpHeader +} + +type ThreadRequest struct { + Messages []ThreadMessage `json:"messages,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *ToolResourcesRequest `json:"tool_resources,omitempty"` +} + +type ToolResources struct { + CodeInterpreter *CodeInterpreterToolResources `json:"code_interpreter,omitempty"` + FileSearch *FileSearchToolResources `json:"file_search,omitempty"` +} + +type CodeInterpreterToolResources struct { + FileIDs []string `json:"file_ids,omitempty"` +} + +type FileSearchToolResources struct { + VectorStoreIDs []string `json:"vector_store_ids,omitempty"` +} + +type ToolResourcesRequest struct { + CodeInterpreter *CodeInterpreterToolResourcesRequest `json:"code_interpreter,omitempty"` + FileSearch *FileSearchToolResourcesRequest `json:"file_search,omitempty"` +} + +type CodeInterpreterToolResourcesRequest struct { + FileIDs []string `json:"file_ids,omitempty"` +} + +type FileSearchToolResourcesRequest struct { + VectorStoreIDs []string `json:"vector_store_ids,omitempty"` + VectorStores []VectorStoreToolResources `json:"vector_stores,omitempty"` +} + +type VectorStoreToolResources struct { + FileIDs []string `json:"file_ids,omitempty"` + ChunkingStrategy *ChunkingStrategy `json:"chunking_strategy,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ChunkingStrategy struct { + Type ChunkingStrategyType `json:"type"` + Static *StaticChunkingStrategy `json:"static,omitempty"` +} + +type StaticChunkingStrategy struct { + MaxChunkSizeTokens int `json:"max_chunk_size_tokens"` + ChunkOverlapTokens int `json:"chunk_overlap_tokens"` +} + +type ChunkingStrategyType string + +const ( + ChunkingStrategyTypeAuto ChunkingStrategyType = "auto" + ChunkingStrategyTypeStatic ChunkingStrategyType = "static" +) + +type ModifyThreadRequest struct { + Metadata map[string]any `json:"metadata"` + ToolResources *ToolResources `json:"tool_resources,omitempty"` +} + +type ThreadMessageRole string + +const ( + ThreadMessageRoleAssistant ThreadMessageRole = "assistant" + ThreadMessageRoleUser ThreadMessageRole = "user" +) + +type ThreadMessage struct { + Role ThreadMessageRole `json:"role"` + Content string `json:"content"` + FileIDs []string `json:"file_ids,omitempty"` + Attachments []ThreadAttachment `json:"attachments,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ThreadAttachment struct { + FileID string `json:"file_id"` + Tools []ThreadAttachmentTool `json:"tools"` +} + +type ThreadAttachmentTool struct { + Type string `json:"type"` +} + +type ThreadDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + +// CreateThread creates a new thread. +func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (response Thread, err error) { + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(threadsSuffix), withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveThread retrieves a thread. +func (c *Client) RetrieveThread(ctx context.Context, threadID string) (response Thread, err error) { + urlSuffix := threadsSuffix + "/" + threadID + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ModifyThread modifies a thread. +func (c *Client) ModifyThread( + ctx context.Context, + threadID string, + request ModifyThreadRequest, +) (response Thread, err error) { + urlSuffix := threadsSuffix + "/" + threadID + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// DeleteThread deletes a thread. +func (c *Client) DeleteThread( + ctx context.Context, + threadID string, +) (response ThreadDeleteResponse, err error) { + urlSuffix := threadsSuffix + "/" + threadID + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/thread_test.go b/thread_test.go new file mode 100644 index 000000000..1ac0f3c0e --- /dev/null +++ b/thread_test.go @@ -0,0 +1,178 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +// TestThread Tests the thread endpoint of the API using the mocked server. +func TestThread(t *testing.T) { + threadID := "thread_abc123" + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/threads/"+threadID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.ThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "thread_abc123", + "object": "thread.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/threads", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.ModifyThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + Metadata: request.Metadata, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateThread(ctx, openai.ThreadRequest{ + Messages: []openai.ThreadMessage{ + { + Role: openai.ThreadMessageRoleUser, + Content: "Hello, World!", + }, + }, + }) + checks.NoError(t, err, "CreateThread error") + + _, err = client.RetrieveThread(ctx, threadID) + checks.NoError(t, err, "RetrieveThread error") + + _, err = client.ModifyThread(ctx, threadID, openai.ModifyThreadRequest{ + Metadata: map[string]interface{}{ + "key": "value", + }, + }) + checks.NoError(t, err, "ModifyThread error") + + _, err = client.DeleteThread(ctx, threadID) + checks.NoError(t, err, "DeleteThread error") +} + +// TestAzureThread Tests the thread endpoint of the API using the Azure mocked server. +func TestAzureThread(t *testing.T) { + threadID := "thread_abc123" + client, server, teardown := setupAzureTestServer() + defer teardown() + + server.RegisterHandler( + "/openai/threads/"+threadID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.ThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "thread_abc123", + "object": "thread.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/openai/threads", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.ModifyThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + Metadata: request.Metadata, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateThread(ctx, openai.ThreadRequest{ + Messages: []openai.ThreadMessage{ + { + Role: openai.ThreadMessageRoleUser, + Content: "Hello, World!", + }, + }, + }) + checks.NoError(t, err, "CreateThread error") + + _, err = client.RetrieveThread(ctx, threadID) + checks.NoError(t, err, "RetrieveThread error") + + _, err = client.ModifyThread(ctx, threadID, openai.ModifyThreadRequest{ + Metadata: map[string]interface{}{ + "key": "value", + }, + }) + checks.NoError(t, err, "ModifyThread error") + + _, err = client.DeleteThread(ctx, threadID) + checks.NoError(t, err, "DeleteThread error") +} diff --git a/unmarshaler.go b/unmarshaler.go deleted file mode 100644 index 05218f764..000000000 --- a/unmarshaler.go +++ /dev/null @@ -1,15 +0,0 @@ -package openai - -import ( - "encoding/json" -) - -type unmarshaler interface { - unmarshal(data []byte, v any) error -} - -type jsonUnmarshaler struct{} - -func (jm *jsonUnmarshaler) unmarshal(data []byte, v any) error { - return json.Unmarshal(data, v) -} diff --git a/vector_store.go b/vector_store.go new file mode 100644 index 000000000..682bb1cf9 --- /dev/null +++ b/vector_store.go @@ -0,0 +1,348 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +const ( + vectorStoresSuffix = "/vector_stores" + vectorStoresFilesSuffix = "/files" + vectorStoresFileBatchesSuffix = "/file_batches" +) + +type VectorStoreFileCount struct { + InProgress int `json:"in_progress"` + Completed int `json:"completed"` + Failed int `json:"failed"` + Cancelled int `json:"cancelled"` + Total int `json:"total"` +} + +type VectorStore struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name string `json:"name"` + UsageBytes int `json:"usage_bytes"` + FileCounts VectorStoreFileCount `json:"file_counts"` + Status string `json:"status"` + ExpiresAfter *VectorStoreExpires `json:"expires_after"` + ExpiresAt *int `json:"expires_at"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type VectorStoreExpires struct { + Anchor string `json:"anchor"` + Days int `json:"days"` +} + +// VectorStoreRequest provides the vector store request parameters. +type VectorStoreRequest struct { + Name string `json:"name,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + ExpiresAfter *VectorStoreExpires `json:"expires_after,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// VectorStoresList is a list of vector store. +type VectorStoresList struct { + VectorStores []VectorStore `json:"data"` + LastID *string `json:"last_id"` + FirstID *string `json:"first_id"` + HasMore bool `json:"has_more"` + httpHeader +} + +type VectorStoreDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + +type VectorStoreFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + VectorStoreID string `json:"vector_store_id"` + UsageBytes int `json:"usage_bytes"` + Status string `json:"status"` + + httpHeader +} + +type VectorStoreFileRequest struct { + FileID string `json:"file_id"` +} + +type VectorStoreFilesList struct { + VectorStoreFiles []VectorStoreFile `json:"data"` + FirstID *string `json:"first_id"` + LastID *string `json:"last_id"` + HasMore bool `json:"has_more"` + + httpHeader +} + +type VectorStoreFileBatch struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + VectorStoreID string `json:"vector_store_id"` + Status string `json:"status"` + FileCounts VectorStoreFileCount `json:"file_counts"` + + httpHeader +} + +type VectorStoreFileBatchRequest struct { + FileIDs []string `json:"file_ids"` +} + +// CreateVectorStore creates a new vector store. +func (c *Client) CreateVectorStore(ctx context.Context, request VectorStoreRequest) (response VectorStore, err error) { + req, _ := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(vectorStoresSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStore retrieves an vector store. +func (c *Client) RetrieveVectorStore( + ctx context.Context, + vectorStoreID string, +) (response VectorStore, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ModifyVectorStore modifies a vector store. +func (c *Client) ModifyVectorStore( + ctx context.Context, + vectorStoreID string, + request VectorStoreRequest, +) (response VectorStore, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// DeleteVectorStore deletes an vector store. +func (c *Client) DeleteVectorStore( + ctx context.Context, + vectorStoreID string, +) (response VectorStoreDeleteResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ListVectorStores Lists the currently available vector store. +func (c *Client) ListVectorStores( + ctx context.Context, + pagination Pagination, +) (response VectorStoresList, err error) { + urlValues := url.Values{} + + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", vectorStoresSuffix, encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CreateVectorStoreFile creates a new vector store file. +func (c *Client) CreateVectorStoreFile( + ctx context.Context, + vectorStoreID string, + request VectorStoreFileRequest, +) (response VectorStoreFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStoreFile retrieves a vector store file. +func (c *Client) RetrieveVectorStoreFile( + ctx context.Context, + vectorStoreID string, + fileID string, +) (response VectorStoreFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, fileID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// DeleteVectorStoreFile deletes an existing file. +func (c *Client) DeleteVectorStoreFile( + ctx context.Context, + vectorStoreID string, + fileID string, +) (err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, fileID) + req, _ := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, nil) + return +} + +// ListVectorStoreFiles Lists the currently available files for a vector store. +func (c *Client) ListVectorStoreFiles( + ctx context.Context, + vectorStoreID string, + pagination Pagination, +) (response VectorStoreFilesList, err error) { + urlValues := url.Values{} + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CreateVectorStoreFileBatch creates a new vector store file batch. +func (c *Client) CreateVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + request VectorStoreFileBatchRequest, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFileBatchesSuffix) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStoreFileBatch retrieves a vector store file batch. +func (c *Client) RetrieveVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + batchID string, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFileBatchesSuffix, batchID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CancelVectorStoreFileBatch cancel a new vector store file batch. +func (c *Client) CancelVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + batchID string, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s%s", vectorStoresSuffix, + vectorStoreID, vectorStoresFileBatchesSuffix, batchID, "/cancel") + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ListVectorStoreFiles Lists the currently available files for a vector store. +func (c *Client) ListVectorStoreFilesInBatch( + ctx context.Context, + vectorStoreID string, + batchID string, + pagination Pagination, +) (response VectorStoreFilesList, err error) { + urlValues := url.Values{} + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s/%s%s%s", vectorStoresSuffix, + vectorStoreID, vectorStoresFileBatchesSuffix, batchID, "/files", encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} diff --git a/vector_store_test.go b/vector_store_test.go new file mode 100644 index 000000000..58b9a857e --- /dev/null +++ b/vector_store_test.go @@ -0,0 +1,349 @@ +package openai_test + +import ( + "context" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +// TestVectorStore Tests the vector store endpoint of the API using the mocked server. +func TestVectorStore(t *testing.T) { + vectorStoreID := "vs_abc123" + vectorStoreName := "TestStore" + vectorStoreFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" + vectorStoreFileBatchID := "vsfb_abc123" + limit := 20 + order := "desc" + after := "vs_abc122" + before := "vs_abc123" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/files/"+vectorStoreFileID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFile{ + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodDelete { + fmt.Fprintln(w, `{ + id: "file-wB6RM6wHdA49HfS2DJ9fEyrH", + object: "vector_store.file.deleted", + deleted: true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFilesList{ + VectorStoreFiles: []openai.VectorStoreFile{ + { + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.VectorStoreFileRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStoreFile{ + ID: request.FileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFilesList{ + VectorStoreFiles: []openai.VectorStoreFile{ + { + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID+"/cancel", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "cancelling", + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: 1, + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + FileCounts: openai.VectorStoreFileCount{ + Completed: 1, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "cancelling", + FileCounts: openai.VectorStoreFileCount{ + Completed: 1, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.VectorStoreFileBatchRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: len(request.FileIDs), + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: vectorStoreName, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.VectorStore + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: request.Name, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "vectorstore_abc123", + "object": "vector_store.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.VectorStoreRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: request.Name, + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: 0, + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoresList{ + LastID: &vectorStoreID, + FirstID: &vectorStoreID, + VectorStores: []openai.VectorStore{ + { + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: vectorStoreName, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + t.Run("create_vector_store", func(t *testing.T) { + _, err := client.CreateVectorStore(ctx, openai.VectorStoreRequest{ + Name: vectorStoreName, + }) + checks.NoError(t, err, "CreateVectorStore error") + }) + + t.Run("retrieve_vector_store", func(t *testing.T) { + _, err := client.RetrieveVectorStore(ctx, vectorStoreID) + checks.NoError(t, err, "RetrieveVectorStore error") + }) + + t.Run("delete_vector_store", func(t *testing.T) { + _, err := client.DeleteVectorStore(ctx, vectorStoreID) + checks.NoError(t, err, "DeleteVectorStore error") + }) + + t.Run("list_vector_store", func(t *testing.T) { + _, err := client.ListVectorStores(context.TODO(), openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStores error") + }) + + t.Run("create_vector_store_file", func(t *testing.T) { + _, err := client.CreateVectorStoreFile(context.TODO(), vectorStoreID, openai.VectorStoreFileRequest{ + FileID: vectorStoreFileID, + }) + checks.NoError(t, err, "CreateVectorStoreFile error") + }) + + t.Run("list_vector_store_files", func(t *testing.T) { + _, err := client.ListVectorStoreFiles(ctx, vectorStoreID, openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStoreFiles error") + }) + + t.Run("retrieve_vector_store_file", func(t *testing.T) { + _, err := client.RetrieveVectorStoreFile(ctx, vectorStoreID, vectorStoreFileID) + checks.NoError(t, err, "RetrieveVectorStoreFile error") + }) + + t.Run("delete_vector_store_file", func(t *testing.T) { + err := client.DeleteVectorStoreFile(ctx, vectorStoreID, vectorStoreFileID) + checks.NoError(t, err, "DeleteVectorStoreFile error") + }) + + t.Run("modify_vector_store", func(t *testing.T) { + _, err := client.ModifyVectorStore(ctx, vectorStoreID, openai.VectorStoreRequest{ + Name: vectorStoreName, + }) + checks.NoError(t, err, "ModifyVectorStore error") + }) + + t.Run("create_vector_store_file_batch", func(t *testing.T) { + _, err := client.CreateVectorStoreFileBatch(ctx, vectorStoreID, openai.VectorStoreFileBatchRequest{ + FileIDs: []string{vectorStoreFileID}, + }) + checks.NoError(t, err, "CreateVectorStoreFileBatch error") + }) + + t.Run("retrieve_vector_store_file_batch", func(t *testing.T) { + _, err := client.RetrieveVectorStoreFileBatch(ctx, vectorStoreID, vectorStoreFileBatchID) + checks.NoError(t, err, "RetrieveVectorStoreFileBatch error") + }) + + t.Run("list_vector_store_files_in_batch", func(t *testing.T) { + _, err := client.ListVectorStoreFilesInBatch( + ctx, + vectorStoreID, + vectorStoreFileBatchID, + openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStoreFilesInBatch error") + }) + + t.Run("cancel_vector_store_file_batch", func(t *testing.T) { + _, err := client.CancelVectorStoreFileBatch(ctx, vectorStoreID, vectorStoreFileBatchID) + checks.NoError(t, err, "CancelVectorStoreFileBatch error") + }) +}