diff --git a/flights/flight.go b/flights/flight.go index 39f366e..d89162b 100644 --- a/flights/flight.go +++ b/flights/flight.go @@ -21,6 +21,9 @@ import ( const ( flightAirportConst rune = '0' flightCityConst rune = '5' + // maxLocationsPerRequest is the maximum number of locations that can be + // sent in a request. Google Flights only allows 7 locations per request. + maxLocationsPerRequest = 7 ) // Google Flight API requests need different enum values than Google Flight URLs @@ -305,24 +308,52 @@ func getSectionOffers(bytesToDecode []byte, returnDate time.Time) ([]FullOffer, return allOffers, &priceRange, nil } -// GetOffers retrieves offers from the Google Flight search. The city names should be provided in the language -// described by args.Lang. The offers are returned in a slice of [FullOffer]. -// -// GetOffers also returns [*PriceRange], which contains the low and high prices of the search. The values are -// taken from the "View price history" subsection of the search. If the search doesn't have the "View -// price history" subsection, then GetOffers returns nil. -// -// GetPriceGraph returns an error if any of the requests fail or if any of the city names are misspelled. -// -// Requirements are described by the [Args.ValidateOffersArgs] function. -func (s *Session) GetOffers(ctx context.Context, args Args) ([]FullOffer, *PriceRange, error) { - if err := args.ValidateOffersArgs(); err != nil { - return nil, nil, err +type locationGroup struct { + cities []string + airports []string +} + +func splitLocations(cities, airports []string) []locationGroup { + total := len(cities) + len(airports) + if total == 0 { + return []locationGroup{{}} } - finalOffers := []FullOffer{} - var finalPriceRange *PriceRange + type location struct { + value string + isCity bool + } + + combined := make([]location, 0, total) + for _, airport := range airports { + combined = append(combined, location{value: airport}) + } + for _, city := range cities { + combined = append(combined, location{value: city, isCity: true}) + } + + chunks := []locationGroup{} + for start := 0; start < len(combined); start += maxLocationsPerRequest { + end := start + maxLocationsPerRequest + if end > len(combined) { + end = len(combined) + } + + group := locationGroup{} + for _, loc := range combined[start:end] { + if loc.isCity { + group.cities = append(group.cities, loc.value) + continue + } + group.airports = append(group.airports, loc.value) + } + chunks = append(chunks, group) + } + + return chunks +} +func (s *Session) getOffersForArgs(ctx context.Context, args Args) ([]FullOffer, *PriceRange, error) { resp, err := s.doRequestFlights(ctx, args) if err != nil { return nil, nil, err @@ -332,19 +363,68 @@ func (s *Session) GetOffers(ctx context.Context, args Args) ([]FullOffer, *Price body := bufio.NewReader(resp.Body) skipPrefix(body) + offers := []FullOffer{} + var priceRange *PriceRange + for { readLine(body) // skip line bytesToDecode, err := getInnerBytes(body) if err != nil { - return finalOffers, finalPriceRange, nil + return offers, priceRange, nil } - offers, priceRange, _ := getSectionOffers(bytesToDecode, args.ReturnDate) - if offers != nil { - finalOffers = append(finalOffers, offers...) + newOffers, newPriceRange, _ := getSectionOffers(bytesToDecode, args.ReturnDate) + if newOffers != nil { + offers = append(offers, newOffers...) } - if priceRange != nil { - finalPriceRange = priceRange + if newPriceRange != nil { + priceRange = newPriceRange } } } + +// GetOffers retrieves offers from the Google Flight search. The city names should be provided in the language +// described by args.Lang. The offers are returned in a slice of [FullOffer]. +// +// GetOffers also returns [*PriceRange], which contains the low and high prices of the search. The values are +// taken from the "View price history" subsection of the search. If the search doesn't have the "View +// price history" subsection, then GetOffers returns nil. +// +// GetPriceGraph returns an error if any of the requests fail or if any of the city names are misspelled. +// +// Requirements are described by the [Args.ValidateOffersArgs] function. +func (s *Session) GetOffers(ctx context.Context, args Args) ([]FullOffer, *PriceRange, error) { + if err := args.ValidateOffersArgs(); err != nil { + return nil, nil, err + } + + srcGroups := splitLocations(args.SrcCities, args.SrcAirports) + dstGroups := splitLocations(args.DstCities, args.DstAirports) + + finalOffers := []FullOffer{} + var finalPriceRange *PriceRange + + for _, srcGroup := range srcGroups { + for _, dstGroup := range dstGroups { + batchArgs := args + batchArgs.SrcCities = srcGroup.cities + batchArgs.SrcAirports = srcGroup.airports + batchArgs.DstCities = dstGroup.cities + batchArgs.DstAirports = dstGroup.airports + + offers, priceRange, err := s.getOffersForArgs(ctx, batchArgs) + if err != nil { + return nil, nil, err + } + + if offers != nil { + finalOffers = append(finalOffers, offers...) + } + if priceRange != nil { + finalPriceRange = priceRange + } + } + } + + return finalOffers, finalPriceRange, nil +} diff --git a/flights/flight_test.go b/flights/flight_test.go index ba5fc57..88a5067 100644 --- a/flights/flight_test.go +++ b/flights/flight_test.go @@ -274,6 +274,56 @@ func TestGetOffersMock(t *testing.T) { } } +func TestGetOffersSplitsRequests(t *testing.T) { + timeNow = func() time.Time { + t, _ := time.Parse(time.RFC3339, "2024-01-15T00:00:00Z") + return t + } + defer func() { timeNow = time.Now }() + + date, _ := time.Parse(time.RFC3339, "2024-01-20T00:00:00Z") + returnDate, _ := time.Parse(time.RFC3339, "2024-01-25T00:00:00Z") + + srcAirports := []string{"AAA", "AAB", "AAC", "AAD", "AAE", "AAF", "AAG", "AAH"} + dstAirports := []string{"BAA", "BAB", "BAC", "BAD", "BAE", "BAF", "BAG", "BAH"} + + httpClientMock, err := newHttpClientMock( + t, + "testdata/flight.resp", + "testdata/flight.resp", + "testdata/flight.resp", + "testdata/flight.resp", + ) + if err != nil { + t.Fatal(err) + } + + session := &Session{ + client: httpClientMock, + } + + args := Args{ + Date: date, + ReturnDate: returnDate, + SrcAirports: srcAirports, + DstAirports: dstAirports, + Options: OptionsDefault(), + } + + offers, _, err := session.GetOffers(context.Background(), args) + if err != nil { + t.Fatal(err) + } + + if len(offers) != 84 { + t.Fatalf("expected 84 offers from split requests, got %d", len(offers)) + } + + if len(httpClientMock.Responses) != 0 { + t.Fatalf("expected all mock responses to be consumed, remaining: %d", len(httpClientMock.Responses)) + } +} + func TestFlightReqData(t *testing.T) { session, err := New() if err != nil {