diff --git a/openmeter/billing/service/gatheringinvoicependinglines.go b/openmeter/billing/service/gatheringinvoicependinglines.go index 2c271b2c3..ceb48eec8 100644 --- a/openmeter/billing/service/gatheringinvoicependinglines.go +++ b/openmeter/billing/service/gatheringinvoicependinglines.go @@ -275,6 +275,7 @@ type gatheringInvoiceWithFeatureMeters struct { Invoice billing.StandardInvoice FeatureMeters billing.FeatureMeters } + type gatherInScopeLineInput struct { GatheringInvoicesByCurrency map[currencyx.Code]gatheringInvoiceWithFeatureMeters // If set restricts the lines to be included to these IDs, otherwise the AsOf is used @@ -292,12 +293,12 @@ func (s *Service) gatherInScopeLines(ctx context.Context, in gatherInScopeLineIn billableLineIDs := make(map[string]interface{}) for currency, invoice := range in.GatheringInvoicesByCurrency { - lineSrvs, err := s.lineService.FromEntities(invoice.Invoice.Lines.OrEmpty(), invoice.FeatureMeters) + lineSrvs, err := lineservice.FromEntities(invoice.Invoice.Lines.OrEmpty(), invoice.FeatureMeters) if err != nil { return nil, err } - linesWithResolvedPeriods, err := lineSrvs.ResolveBillablePeriod(ctx, lineservice.ResolveBillablePeriodInput{ + linesWithResolvedPeriods, err := lineSrvs.ResolveBillablePeriod(lineservice.ResolveBillablePeriodInput{ AsOf: in.AsOf, ProgressiveBilling: in.ProgressiveBilling, }) @@ -555,17 +556,17 @@ func (s *Service) splitGatheringInvoiceLine(ctx context.Context, in splitGatheri l.ChildUniqueReferenceID = nil }) - postSplitAtLineSvc, err := s.lineService.FromEntity(postSplitAtLine, in.FeatureMeters) + postSplitAtLineSvc, err := lineservice.FromEntity(postSplitAtLine, in.FeatureMeters) if err != nil { return res, fmt.Errorf("creating line service: %w", err) } if !postSplitAtLineSvc.IsPeriodEmptyConsideringTruncations() { - gatheringInvoice.Lines.Append(postSplitAtLine) - - if err := postSplitAtLineSvc.Validate(ctx, &gatheringInvoice); err != nil { + if err := postSplitAtLine.Validate(); err != nil { return res, fmt.Errorf("validating post split line: %w", err) } + + gatheringInvoice.Lines.Append(postSplitAtLine) } // Let's update the original line to only contain the period up to the splitAt time @@ -574,7 +575,9 @@ func (s *Service) splitGatheringInvoiceLine(ctx context.Context, in splitGatheri line.SplitLineGroupID = lo.ToPtr(splitLineGroupID) line.ChildUniqueReferenceID = nil - preSplitAtLineSvc, err := s.lineService.FromEntity(line, in.FeatureMeters) + preSplitAtLine := line + + preSplitAtLineSvc, err := lineservice.FromEntity(line, in.FeatureMeters) if err != nil { return res, fmt.Errorf("creating line service: %w", err) } @@ -583,7 +586,7 @@ func (s *Service) splitGatheringInvoiceLine(ctx context.Context, in splitGatheri if preSplitAtLineSvc.IsPeriodEmptyConsideringTruncations() { line.DeletedAt = lo.ToPtr(clock.Now()) } else { - if err := preSplitAtLineSvc.Validate(ctx, &gatheringInvoice); err != nil { + if err := preSplitAtLine.Validate(); err != nil { return res, fmt.Errorf("validating pre split line: %w", err) } } @@ -824,16 +827,17 @@ func (s *Service) moveLinesToInvoice(ctx context.Context, in moveLinesToInvoiceI return slices.Contains(in.LineIDsToMove, line.ID) }) - if len(linesToMove) != len(in.LineIDsToMove) { - return nil, fmt.Errorf("lines to move[%d] must contain the same number of lines as line IDs to move[%d]", len(linesToMove), len(in.LineIDsToMove)) + for _, line := range linesToMove { + if line.Currency != dstInvoice.Currency { + return nil, fmt.Errorf("line[%s]: currency[%s] is not equal to target invoice currency[%s]", line.ID, line.Currency, dstInvoice.Currency) + } } - linesToAssociate, err := s.lineService.FromEntities(linesToMove, in.FeatureMeters) - if err != nil { - return nil, fmt.Errorf("creating line services for lines to move: %w", err) + if len(linesToMove) != len(in.LineIDsToMove) { + return nil, fmt.Errorf("lines to move[%d] must contain the same number of lines as line IDs to move[%d]", len(linesToMove), len(in.LineIDsToMove)) } - if err := linesToAssociate.ValidateForInvoice(ctx, &dstInvoice); err != nil { + if err := linesToMove.Validate(); err != nil { return nil, fmt.Errorf("validating lines to move: %w", err) } diff --git a/openmeter/billing/service/invoice.go b/openmeter/billing/service/invoice.go index 1f0306a1b..75674d5f9 100644 --- a/openmeter/billing/service/invoice.go +++ b/openmeter/billing/service/invoice.go @@ -150,15 +150,15 @@ func (s *Service) recalculateGatheringInvoice(ctx context.Context, in recalculat return invoice, fmt.Errorf("fetching profile: %w", err) } + if customerProfile.Customer == nil { + return invoice, fmt.Errorf("customer profile is nil") + } + featureMeters, err := s.resolveFeatureMeters(ctx, invoice.Lines.OrEmpty()) if err != nil { return invoice, fmt.Errorf("resolving feature meters: %w", err) } - if customerProfile.Customer == nil { - return invoice, fmt.Errorf("customer profile is nil") - } - inScopeLines := lo.Filter(invoice.Lines.OrEmpty(), func(line *billing.StandardLine, _ int) bool { return line.DeletedAt == nil }) @@ -167,14 +167,14 @@ func (s *Service) recalculateGatheringInvoice(ctx context.Context, in recalculat return invoice, fmt.Errorf("snapshotting lines: %w", err) } - inScopeLineSvcs, err := s.lineService.FromEntities(inScopeLines, featureMeters) + inScopeLineSvcs, err := lineservice.FromEntities(inScopeLines, featureMeters) if err != nil { return invoice, fmt.Errorf("creating line services: %w", err) } hasInvoicableLines := mo.Option[bool]{} for _, lineSvc := range inScopeLineSvcs { - period, err := lineSvc.CanBeInvoicedAsOf(ctx, lineservice.CanBeInvoicedAsOfInput{ + period, err := lineSvc.CanBeInvoicedAsOf(lineservice.CanBeInvoicedAsOfInput{ AsOf: now, ProgressiveBilling: customerProfile.MergedProfile.WorkflowConfig.Invoicing.ProgressiveBilling, }) @@ -190,7 +190,6 @@ func (s *Service) recalculateGatheringInvoice(ctx context.Context, in recalculat invoice.QuantitySnapshotedAt = lo.ToPtr(now) if err := s.invoiceCalculator.CalculateGatheringInvoiceWithLiveData(&invoice, invoicecalc.CalculatorDependencies{ - LineService: s.lineService, FeatureMeters: featureMeters, }); err != nil { return invoice, fmt.Errorf("calculating invoice: %w", err) @@ -570,18 +569,11 @@ func (s *Service) UpdateInvoice(ctx context.Context, input billing.UpdateInvoice return billing.StandardInvoice{}, fmt.Errorf("editing invoice: %w", err) } - featureMeters, err := s.resolveFeatureMeters(ctx, invoice.Lines.OrEmpty()) - if err != nil { - return billing.StandardInvoice{}, fmt.Errorf("resolving feature meters: %w", err) - } - - normalizedLines, err := invoice.Lines.WithNormalizedValues() + invoice.Lines, err = invoice.Lines.WithNormalizedValues() if err != nil { return billing.StandardInvoice{}, fmt.Errorf("normalizing lines: %w", err) } - invoice.Lines = normalizedLines - if err := s.invoiceCalculator.CalculateGatheringInvoice(&invoice); err != nil { return billing.StandardInvoice{}, fmt.Errorf("calculating invoice[%s]: %w", invoice.ID, err) } @@ -592,6 +584,11 @@ func (s *Service) UpdateInvoice(ctx context.Context, input billing.UpdateInvoice } } + featureMeters, err := s.resolveFeatureMeters(ctx, invoice.Lines.OrEmpty()) + if err != nil { + return billing.StandardInvoice{}, fmt.Errorf("resolving feature meters: %w", err) + } + // Check if the new lines are still invoicable if err := s.checkIfLinesAreInvoicable(ctx, &invoice, customerProfile.MergedProfile.WorkflowConfig.Invoicing.ProgressiveBilling, featureMeters); err != nil { return billing.StandardInvoice{}, err @@ -662,23 +659,22 @@ func (s Service) updateInvoice(ctx context.Context, in billing.UpdateInvoiceAdap } func (s Service) checkIfLinesAreInvoicable(ctx context.Context, invoice *billing.StandardInvoice, progressiveBilling bool, featureMeters billing.FeatureMeters) error { - inScopeLineServices, err := s.lineService.FromEntities( - lo.Filter(invoice.Lines.OrEmpty(), func(line *billing.StandardLine, _ int) bool { - return line.DeletedAt == nil - }), - featureMeters, - ) - if err != nil { - return fmt.Errorf("creating line services: %w", err) - } + linesToCheck := lo.Filter(invoice.Lines.OrEmpty(), func(line *billing.StandardLine, _ int) bool { + return line.DeletedAt == nil + }) return errors.Join( - lo.Map(inScopeLineServices, func(lineSvc lineservice.Line, _ int) error { - if err := lineSvc.Validate(ctx, invoice); err != nil { - return fmt.Errorf("validating line[%s]: %w", lineSvc.ID(), err) + lo.Map(linesToCheck, func(line *billing.StandardLine, _ int) error { + if err := line.Validate(); err != nil { + return fmt.Errorf("validating line[%s]: %w", line.ID, err) + } + + lineSvc, err := lineservice.FromEntity(line, featureMeters) + if err != nil { + return fmt.Errorf("creating line service: %w", err) } - period, err := lineSvc.CanBeInvoicedAsOf(ctx, lineservice.CanBeInvoicedAsOfInput{ + period, err := lineSvc.CanBeInvoicedAsOf(lineservice.CanBeInvoicedAsOfInput{ AsOf: lineSvc.InvoiceAt(), ProgressiveBilling: progressiveBilling, }) @@ -799,17 +795,15 @@ func (s Service) SimulateInvoice(ctx context.Context, input billing.SimulateInvo return billing.StandardInvoice{}, fmt.Errorf("resolving feature meters: %w", err) } - inScopeLineSvcs, err := s.lineService.FromEntities(invoice.Lines.OrEmpty(), featureMeters) - if err != nil { - return billing.StandardInvoice{}, fmt.Errorf("creating line services: %w", err) - } - // Let's update the lines and the detailed lines - for _, lineSvc := range inScopeLineSvcs { - if err := lineSvc.Validate(ctx, &invoice); err != nil { - return billing.StandardInvoice{}, billing.ValidationError{ - Err: err, - } + for _, line := range invoice.Lines.OrEmpty() { + if err := line.Validate(); err != nil { + return billing.StandardInvoice{}, fmt.Errorf("validating line[%s]: %w", line.ID, err) + } + + lineSvc, err := lineservice.FromEntity(line, featureMeters) + if err != nil { + return billing.StandardInvoice{}, fmt.Errorf("creating line service: %w", err) } if err := lineSvc.CalculateDetailedLines(); err != nil { @@ -823,7 +817,6 @@ func (s Service) SimulateInvoice(ctx context.Context, input billing.SimulateInvo // Let's simulate a recalculation of the invoice if err := s.invoiceCalculator.Calculate(&invoice, invoicecalc.CalculatorDependencies{ - LineService: s.lineService, FeatureMeters: featureMeters, }); err != nil { return billing.StandardInvoice{}, err diff --git a/openmeter/billing/service/invoicecalc/calculator.go b/openmeter/billing/service/invoicecalc/calculator.go index 5a6b1a5f3..ccd39411f 100644 --- a/openmeter/billing/service/invoicecalc/calculator.go +++ b/openmeter/billing/service/invoicecalc/calculator.go @@ -4,7 +4,6 @@ import ( "errors" "github.com/openmeterio/openmeter/openmeter/billing" - "github.com/openmeterio/openmeter/openmeter/billing/service/lineservice" ) type invoiceCalculatorsByType struct { @@ -50,7 +49,6 @@ type Calculator interface { } type CalculatorDependencies struct { - LineService *lineservice.Service FeatureMeters billing.FeatureMeters } diff --git a/openmeter/billing/service/invoicecalc/details.go b/openmeter/billing/service/invoicecalc/details.go index d1f39820b..2edc5b375 100644 --- a/openmeter/billing/service/invoicecalc/details.go +++ b/openmeter/billing/service/invoicecalc/details.go @@ -7,6 +7,7 @@ import ( "github.com/samber/lo" "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/billing/service/lineservice" ) func RecalculateDetailedLinesAndTotals(invoice *billing.StandardInvoice, deps CalculatorDependencies) error { @@ -14,7 +15,7 @@ func RecalculateDetailedLinesAndTotals(invoice *billing.StandardInvoice, deps Ca return errors.New("cannot recaulculate invoice without expanded lines") } - lines, err := deps.LineService.FromEntities(invoice.Lines.OrEmpty(), deps.FeatureMeters) + lines, err := lineservice.FromEntities(invoice.Lines.OrEmpty(), deps.FeatureMeters) if err != nil { return fmt.Errorf("creating line services: %w", err) } diff --git a/openmeter/billing/service/lineservice/linebase.go b/openmeter/billing/service/lineservice/linebase.go index 1046b9c33..12de5a59b 100644 --- a/openmeter/billing/service/lineservice/linebase.go +++ b/openmeter/billing/service/lineservice/linebase.go @@ -1,7 +1,6 @@ package lineservice import ( - "context" "time" "github.com/openmeterio/openmeter/openmeter/billing" @@ -41,7 +40,6 @@ var _ LineBase = (*lineBase)(nil) type lineBase struct { line *billing.StandardLine - service *Service featureMeters billing.FeatureMeters currency currencyx.Calculator } @@ -70,16 +68,6 @@ func (l lineBase) Period() billing.Period { return l.line.Period } -func (l lineBase) Validate(ctx context.Context, invoice *billing.StandardInvoice) error { - if l.line.Currency != invoice.Currency || l.line.Currency == "" { - return billing.ValidationError{ - Err: billing.ErrInvoiceLineCurrencyMismatch, - } - } - - return nil -} - func (l lineBase) IsLastInPeriod() bool { if l.line.SplitLineGroupID == nil { return true @@ -116,10 +104,6 @@ func (l lineBase) IsDeleted() bool { return l.line.DeletedAt != nil } -func (l lineBase) Service() *Service { - return l.service -} - func (l lineBase) ResetTotals() { l.line.Totals = billing.Totals{} } diff --git a/openmeter/billing/service/lineservice/meters.go b/openmeter/billing/service/lineservice/meters.go deleted file mode 100644 index 1b1ecfc0d..000000000 --- a/openmeter/billing/service/lineservice/meters.go +++ /dev/null @@ -1,130 +0,0 @@ -package lineservice - -import ( - "context" - "fmt" - "slices" - - "github.com/alpacahq/alpacadecimal" - - "github.com/openmeterio/openmeter/openmeter/billing" - "github.com/openmeterio/openmeter/openmeter/meter" - "github.com/openmeterio/openmeter/openmeter/productcatalog/feature" - "github.com/openmeterio/openmeter/openmeter/streaming" -) - -type getFeatureUsageInput struct { - Line *billing.StandardLine - Meter meter.Meter - Feature feature.Feature - Customer billing.InvoiceCustomer -} - -func (i getFeatureUsageInput) Validate() error { - if i.Line == nil { - return fmt.Errorf("line is required") - } - - if slices.Contains([]meter.MeterAggregation{ - meter.MeterAggregationAvg, - meter.MeterAggregationMin, - }, i.Meter.Aggregation) { - if i.Line.SplitLineHierarchy != nil { - return fmt.Errorf("aggregation %s is not supported for split lines", i.Meter.Aggregation) - } - } - - if err := i.Customer.Validate(); err != nil { - return fmt.Errorf("customer: %w", err) - } - - return nil -} - -type featureUsageResponse struct { - // LinePeriodQty is the quantity of the usage for the line for the period - LinePeriodQty alpacadecimal.Decimal - // PreLinePeriodQty is the quantity of the usage for the line for the period before the current period - PreLinePeriodQty alpacadecimal.Decimal -} - -func (s *Service) getFeatureUsage(ctx context.Context, in getFeatureUsageInput) (*featureUsageResponse, error) { - // Validation - if err := in.Validate(); err != nil { - return nil, err - } - - meterQueryParams := streaming.QueryParams{ - FilterCustomer: []streaming.Customer{in.Customer}, - From: &in.Line.Period.Start, - To: &in.Line.Period.End, - FilterGroupBy: in.Feature.MeterGroupByFilters, - } - - lineHierarchy := in.Line.SplitLineHierarchy - - // If we are the first line in the split, we don't need to calculate the pre period - if lineHierarchy == nil || lineHierarchy.Group.ServicePeriod.Start.Equal(in.Line.Period.Start) { - meterValues, err := s.StreamingConnector.QueryMeter( - ctx, - in.Line.Namespace, - in.Meter, - meterQueryParams, - ) - if err != nil { - return nil, fmt.Errorf("querying line[%s] meter[%s]: %w", in.Line.ID, in.Meter.Key, err) - } - - return &featureUsageResponse{ - LinePeriodQty: summarizeMeterQueryRow(meterValues), - }, nil - } - - // Let's calculate [parent.start ... line.start] values - preLineQuery := meterQueryParams - preLineQuery.From = &lineHierarchy.Group.ServicePeriod.Start - preLineQuery.To = &in.Line.Period.Start - - preLineResult, err := s.StreamingConnector.QueryMeter( - ctx, - in.Line.Namespace, - in.Meter, - preLineQuery, - ) - if err != nil { - return nil, fmt.Errorf("querying pre line[%s] period meter[%s]: %w", in.Line.ID, in.Meter.Key, err) - } - - preLineQty := summarizeMeterQueryRow(preLineResult) - - // Let's calculate [parent.start ... line.end] values - upToLineEnd := meterQueryParams - upToLineEnd.From = &lineHierarchy.Group.ServicePeriod.Start - upToLineEnd.To = &in.Line.Period.End - - upToLineEndResult, err := s.StreamingConnector.QueryMeter( - ctx, - in.Line.Namespace, - in.Meter, - upToLineEnd, - ) - if err != nil { - return nil, fmt.Errorf("querying up to line[%s] end meter[%s]: %w", in.Line.ID, in.Meter.Key, err) - } - - upToLineQty := summarizeMeterQueryRow(upToLineEndResult) - - return &featureUsageResponse{ - LinePeriodQty: upToLineQty.Sub(preLineQty), - PreLinePeriodQty: preLineQty, - }, nil -} - -func summarizeMeterQueryRow(in []meter.MeterQueryRow) alpacadecimal.Decimal { - sum := alpacadecimal.Decimal{} - for _, row := range in { - sum = sum.Add(alpacadecimal.NewFromFloat(row.Value)) - } - - return sum -} diff --git a/openmeter/billing/service/lineservice/service.go b/openmeter/billing/service/lineservice/service.go index c9a26a15b..32e91948f 100644 --- a/openmeter/billing/service/lineservice/service.go +++ b/openmeter/billing/service/lineservice/service.go @@ -1,8 +1,6 @@ package lineservice import ( - "context" - "errors" "fmt" "time" @@ -10,47 +8,19 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing" "github.com/openmeterio/openmeter/openmeter/productcatalog" - "github.com/openmeterio/openmeter/openmeter/streaming" "github.com/openmeterio/openmeter/pkg/slicesx" ) -type Service struct { - Config -} - -type Config struct { - StreamingConnector streaming.Connector -} - -func (c Config) Validate() error { - if c.StreamingConnector == nil { - return fmt.Errorf("streaming connector is required") - } - - return nil -} - -func New(in Config) (*Service, error) { - if err := in.Validate(); err != nil { - return nil, err - } - - return &Service{ - Config: in, - }, nil -} - -func (s *Service) FromEntity(line *billing.StandardLine, featureMeters billing.FeatureMeters) (Line, error) { +func FromEntity(line *billing.StandardLine, featureMeters billing.FeatureMeters) (Line, error) { currencyCalc, err := line.Currency.Calculator() if err != nil { return nil, fmt.Errorf("creating currency calculator: %w", err) } base := lineBase{ - service: s, - featureMeters: featureMeters, line: line, currency: currencyCalc, + featureMeters: featureMeters, } if line.UsageBased.Price.Type() == productcatalog.FlatPriceType { @@ -64,14 +34,14 @@ func (s *Service) FromEntity(line *billing.StandardLine, featureMeters billing.F }, nil } -func (s *Service) FromEntities(line []*billing.StandardLine, featureMeters billing.FeatureMeters) (Lines, error) { +func FromEntities(line []*billing.StandardLine, featureMeters billing.FeatureMeters) (Lines, error) { return slicesx.MapWithErr(line, func(l *billing.StandardLine) (Line, error) { - return s.FromEntity(l, featureMeters) + return FromEntity(l, featureMeters) }) } // UpdateTotalsFromDetailedLines is a helper method to update the totals of a line from its detailed lines. -func (s *Service) UpdateTotalsFromDetailedLines(line *billing.StandardLine) error { +func UpdateTotalsFromDetailedLines(line *billing.StandardLine) error { // Calculate the line totals for idx, detailedLine := range line.DetailedLines { if detailedLine.DeletedAt != nil { @@ -87,7 +57,7 @@ func (s *Service) UpdateTotalsFromDetailedLines(line *billing.StandardLine) erro } // WARNING: Even if tempting to add discounts etc. here to the totals, we should always keep the logic as is. - // The usageBasedLine will never be syncorinzed directly to stripe or other apps, only the detailed lines. + // The usageBasedLine will never be synchronized directly to stripe or other apps, only the detailed lines. // // Given that the external systems will have their own logic for calculating the totals, we cannot expect // any custom logic implemented here to be carried over to the external systems. @@ -122,39 +92,17 @@ type ( type Line interface { LineBase - Service() *Service - // IsPeriodEmptyConsideringTruncations returns true if the line has an empty period. This is different from Period.IsEmpty() as // this method does any truncation for usage based lines. IsPeriodEmptyConsideringTruncations() bool - Validate(context.Context, *billing.StandardInvoice) error - CanBeInvoicedAsOf(context.Context, CanBeInvoicedAsOfInput) (*billing.Period, error) + CanBeInvoicedAsOf(CanBeInvoicedAsOfInput) (*billing.Period, error) CalculateDetailedLines() error UpdateTotals() error } type Lines []Line -func (s Lines) ValidateForInvoice(ctx context.Context, invoice *billing.StandardInvoice) error { - return errors.Join(lo.Map(s, func(line Line, idx int) error { - if line == nil { - return fmt.Errorf("line[%d] is nil", idx) - } - - if err := line.Validate(ctx, invoice); err != nil { - id := line.ID() - if id == "" { - id = fmt.Sprintf("line[%d]", idx) - } - - return fmt.Errorf("line[%s]: %w", id, err) - } - - return nil - })...) -} - func (s Lines) ToEntities() []*billing.StandardLine { return lo.Map(s, func(service Line, _ int) *billing.StandardLine { return service.ToEntity() @@ -166,10 +114,10 @@ type LineWithBillablePeriod struct { BillablePeriod billing.Period } -func (s Lines) ResolveBillablePeriod(ctx context.Context, in ResolveBillablePeriodInput) ([]LineWithBillablePeriod, error) { +func (s Lines) ResolveBillablePeriod(in ResolveBillablePeriodInput) ([]LineWithBillablePeriod, error) { out := make([]LineWithBillablePeriod, 0, len(s)) for _, lineSrv := range s { - billablePeriod, err := lineSrv.CanBeInvoicedAsOf(ctx, in) + billablePeriod, err := lineSrv.CanBeInvoicedAsOf(in) if err != nil { return nil, fmt.Errorf("checking if line can be invoiced: %w", err) } diff --git a/openmeter/billing/service/lineservice/usagebasedline.go b/openmeter/billing/service/lineservice/usagebasedline.go index f11c4c5ad..023c0d24d 100644 --- a/openmeter/billing/service/lineservice/usagebasedline.go +++ b/openmeter/billing/service/lineservice/usagebasedline.go @@ -1,7 +1,6 @@ package lineservice import ( - "context" "fmt" "github.com/alpacahq/alpacadecimal" @@ -43,31 +42,7 @@ type usageBasedLine struct { lineBase } -func (l usageBasedLine) Validate(ctx context.Context, targetInvoice *billing.StandardInvoice) error { - if _, err := l.featureMeters.Get(l.line.UsageBased.FeatureKey, true); err != nil { - return err - } - - if err := l.lineBase.Validate(ctx, targetInvoice); err != nil { - return err - } - - if err := targetInvoice.Customer.Validate(); err != nil { - return billing.ValidationError{ - Err: billing.ErrInvoiceCreateUBPLineCustomerUsageAttributionInvalid, - } - } - - if l.line.Period.Truncate(streaming.MinimumWindowSizeDuration).IsEmpty() { - return billing.ValidationError{ - Err: billing.ErrInvoiceCreateUBPLinePeriodIsEmpty, - } - } - - return nil -} - -func (l usageBasedLine) CanBeInvoicedAsOf(ctx context.Context, in CanBeInvoicedAsOfInput) (*billing.Period, error) { +func (l usageBasedLine) CanBeInvoicedAsOf(in CanBeInvoicedAsOfInput) (*billing.Period, error) { if !in.ProgressiveBilling { // If we are not doing progressive billing, we can only bill the line if asof >= line.period.end if in.AsOf.Before(l.line.Period.End) { @@ -142,7 +117,7 @@ func (l usageBasedLine) CanBeInvoicedAsOf(ctx context.Context, in CanBeInvoicedA } func (l *usageBasedLine) UpdateTotals() error { - return l.service.UpdateTotalsFromDetailedLines(l.line) + return UpdateTotalsFromDetailedLines(l.line) } func (l *usageBasedLine) CalculateDetailedLines() error { diff --git a/openmeter/billing/service/lineservice/usagebasedline_test.go b/openmeter/billing/service/lineservice/usagebasedline_test.go index e3aeebdff..1e3f79608 100644 --- a/openmeter/billing/service/lineservice/usagebasedline_test.go +++ b/openmeter/billing/service/lineservice/usagebasedline_test.go @@ -29,6 +29,11 @@ var ubpTestFullPeriod = billing.Period{ End: lo.Must(time.Parse(time.RFC3339, "2021-01-02T00:00:00Z")), } +type featureUsageResponse struct { + LinePeriodQty alpacadecimal.Decimal + PreLinePeriodQty alpacadecimal.Decimal +} + type ubpCalculationTestCase struct { price productcatalog.Price discounts billing.Discounts diff --git a/openmeter/billing/service/lineservice/usagebasedlineflat.go b/openmeter/billing/service/lineservice/usagebasedlineflat.go index 1a8d14f2e..df3f96d0e 100644 --- a/openmeter/billing/service/lineservice/usagebasedlineflat.go +++ b/openmeter/billing/service/lineservice/usagebasedlineflat.go @@ -1,8 +1,6 @@ package lineservice import ( - "context" - "errors" "fmt" "github.com/openmeterio/openmeter/openmeter/billing" @@ -14,37 +12,7 @@ type ubpFlatFeeLine struct { lineBase } -func (l ubpFlatFeeLine) Validate(ctx context.Context, targetInvoice *billing.StandardInvoice) error { - var outErr []error - - if l.line.UsageBased.FeatureKey != "" { - _, err := l.featureMeters.Get(l.line.UsageBased.FeatureKey, false) - if err != nil { - outErr = append(outErr, fmt.Errorf("fetching feature[%s]: %w", l.line.UsageBased.FeatureKey, err)) - } - } - - if err := l.lineBase.Validate(ctx, targetInvoice); err != nil { - outErr = append(outErr, err) - } - - // Usage discounts are not allowed - // TODO[later]: Once we have cleaned up the line types, let's move as much as possible to the line's validation - if l.line.RateCardDiscounts.Usage != nil { - outErr = append(outErr, errors.New("usage discounts are not supported for usage based flat fee lines")) - } - - // Percentage discounts are allowed - if l.line.RateCardDiscounts.Percentage != nil { - if err := l.line.RateCardDiscounts.Percentage.Validate(); err != nil { - outErr = append(outErr, err) - } - } - - return errors.Join(outErr...) -} - -func (l ubpFlatFeeLine) CanBeInvoicedAsOf(_ context.Context, in CanBeInvoicedAsOfInput) (*billing.Period, error) { +func (l ubpFlatFeeLine) CanBeInvoicedAsOf(in CanBeInvoicedAsOfInput) (*billing.Period, error) { if !in.AsOf.Before(l.line.InvoiceAt) { return &l.line.Period, nil } @@ -77,7 +45,7 @@ func (l ubpFlatFeeLine) CalculateDetailedLines() error { } func (l *ubpFlatFeeLine) UpdateTotals() error { - return l.service.UpdateTotalsFromDetailedLines(l.line) + return UpdateTotalsFromDetailedLines(l.line) } func (l ubpFlatFeeLine) IsPeriodEmptyConsideringTruncations() bool { diff --git a/openmeter/billing/service/service.go b/openmeter/billing/service/service.go index 63cf163c0..26ec5e96f 100644 --- a/openmeter/billing/service/service.go +++ b/openmeter/billing/service/service.go @@ -9,7 +9,6 @@ import ( "github.com/openmeterio/openmeter/openmeter/app" "github.com/openmeterio/openmeter/openmeter/billing" "github.com/openmeterio/openmeter/openmeter/billing/service/invoicecalc" - "github.com/openmeterio/openmeter/openmeter/billing/service/lineservice" "github.com/openmeterio/openmeter/openmeter/customer" "github.com/openmeterio/openmeter/openmeter/meter" "github.com/openmeterio/openmeter/openmeter/productcatalog/feature" @@ -30,8 +29,7 @@ type Service struct { meterService meter.Service streamingConnector streaming.Connector - lineService *lineservice.Service - publisher eventbus.Publisher + publisher eventbus.Publisher advancementStrategy billing.AdvancementStrategy fsNamespaceLockdown []string @@ -116,15 +114,6 @@ func New(config Config) (*Service, error) { invoiceCalculator: invoicecalc.New(), } - lineSvc, err := lineservice.New(lineservice.Config{ - StreamingConnector: config.StreamingConnector, - }) - if err != nil { - return nil, fmt.Errorf("creating line service: %w", err) - } - - svc.lineService = lineSvc - return svc, nil } diff --git a/openmeter/billing/service/stdinvoiceline.go b/openmeter/billing/service/stdinvoiceline.go index b39a6b48e..8dc9d116c 100644 --- a/openmeter/billing/service/stdinvoiceline.go +++ b/openmeter/billing/service/stdinvoiceline.go @@ -61,11 +61,6 @@ func (s *Service) CreatePendingInvoiceLines(ctx context.Context, input billing.C return nil, nil } - featureMeters, err := s.resolveFeatureMeters(ctx, input.Lines) - if err != nil { - return nil, fmt.Errorf("resolving feature meters: %w", err) - } - // let's resolve the customer's settings customerProfile, err := s.GetCustomerOverride(ctx, billing.GetCustomerOverrideInput{ Customer: input.Customer, @@ -97,25 +92,28 @@ func (s *Service) CreatePendingInvoiceLines(ctx context.Context, input billing.C l.ID = ulid.Make().String() l.InvoiceID = gatheringInvoice.ID - return l.WithNormalizedValues() + normalizedLine, err := l.WithNormalizedValues() + if err != nil { + return nil, fmt.Errorf("normalizing line[%s]: %w", l.ID, err) + } + + if err := normalizedLine.Validate(); err != nil { + return nil, fmt.Errorf("validating line[%s]: %w", l.ID, err) + } + + return normalizedLine, nil }) if err != nil { return nil, fmt.Errorf("mapping lines: %w", err) } - lineServices, err := s.lineService.FromEntities(linesToCreate, featureMeters) - if err != nil { - return nil, fmt.Errorf("creating line services: %w", err) - } - - if err := lineServices.ValidateForInvoice(ctx, &gatheringInvoice); err != nil { - return nil, fmt.Errorf("validating lines: %w", err) - } - gatheringInvoice.Lines.Append(linesToCreate...) - gatheringInvoiceID := gatheringInvoice.ID + if err := gatheringInvoice.Validate(); err != nil { + return nil, fmt.Errorf("validating gathering invoice: %w", err) + } + if err := s.invoiceCalculator.CalculateGatheringInvoice(&gatheringInvoice); err != nil { return nil, fmt.Errorf("calculating invoice[%s]: %w", gatheringInvoiceID, err) } diff --git a/openmeter/billing/service/stdinvoicestate.go b/openmeter/billing/service/stdinvoicestate.go index 0bac1f1a7..1aeb95c5a 100644 --- a/openmeter/billing/service/stdinvoicestate.go +++ b/openmeter/billing/service/stdinvoicestate.go @@ -612,7 +612,6 @@ func (m *InvoiceStateMachine) calculateInvoice(ctx context.Context) error { } return m.Calculator.Calculate(&m.Invoice, invoicecalc.CalculatorDependencies{ - LineService: m.Service.lineService, FeatureMeters: featureMeters, }) } diff --git a/openmeter/billing/stdinvoice.go b/openmeter/billing/stdinvoice.go index 270a87ac2..36b1b788b 100644 --- a/openmeter/billing/stdinvoice.go +++ b/openmeter/billing/stdinvoice.go @@ -291,6 +291,14 @@ func (i StandardInvoice) Validate() error { outErr = errors.Join(outErr, ValidationWithFieldPrefix("lines", err)) } + if i.Lines.IsPresent() { + for _, line := range i.Lines.OrEmpty() { + if line.Currency != i.Currency { + outErr = errors.Join(outErr, fmt.Errorf("line[%s]: currency[%s] is not equal to invoice currency[%s]", line.ID, line.Currency, i.Currency)) + } + } + } + return outErr } diff --git a/openmeter/billing/stdinvoiceline.go b/openmeter/billing/stdinvoiceline.go index 43f8b9f99..d79c56b14 100644 --- a/openmeter/billing/stdinvoiceline.go +++ b/openmeter/billing/stdinvoiceline.go @@ -86,6 +86,18 @@ func (i StandardLineBase) Validate() error { errs = append(errs, fmt.Errorf("invalid managed by %s", i.ManagedBy)) } + if i.RateCardDiscounts.Percentage != nil { + if err := i.RateCardDiscounts.Percentage.Validate(); err != nil { + errs = append(errs, fmt.Errorf("percentage discounts: %w", err)) + } + } + + if i.RateCardDiscounts.Usage != nil { + if err := i.RateCardDiscounts.Usage.Validate(); err != nil { + errs = append(errs, fmt.Errorf("usage discounts: %w", err)) + } + } + return errors.Join(errs...) } @@ -285,6 +297,16 @@ func (i *StandardLine) SaveDBSnapshot() { func (i StandardLine) Validate() error { var errs []error + + // Fail fast cases (most of the validation logic uses these) + if i.UsageBased == nil { + return errors.New("usage based line is required") + } + + if i.UsageBased.Price == nil { + return errors.New("usage based line price is required") + } + if err := i.StandardLineBase.Validate(); err != nil { errs = append(errs, err) } @@ -297,26 +319,36 @@ func (i StandardLine) Validate() error { errs = append(errs, fmt.Errorf("detailed lines: %w", err)) } - if err := i.ValidateUsageBased(); err != nil { - errs = append(errs, err) + for _, detailedLine := range i.DetailedLines { + if detailedLine.Currency != i.Currency { + errs = append(errs, fmt.Errorf("detailed line[%s]: currency[%s] is not equal to line currency[%s]", detailedLine.ID, detailedLine.Currency, i.Currency)) + } } - if err := i.RateCardDiscounts.ValidateForPrice(i.UsageBased.Price); err != nil { - errs = append(errs, fmt.Errorf("rateCardDiscounts: %w", err)) + if err := i.UsageBased.Validate(); err != nil { + errs = append(errs, err) } - return errors.Join(errs...) -} - -func (i StandardLine) ValidateUsageBased() error { - var errs []error + if i.UsageBased.Price.Type() != productcatalog.FlatPriceType { + if i.InvoiceAt. + Truncate(streaming.MinimumWindowSizeDuration). + Before(i.Period.Truncate(streaming.MinimumWindowSizeDuration).End) { + errs = append(errs, fmt.Errorf("invoice at (%s) must be after period end (%s) for usage based line", i.InvoiceAt, i.Period.Truncate(streaming.MinimumWindowSizeDuration).End)) + } - if err := i.UsageBased.Validate(); err != nil { - errs = append(errs, err) + if i.Period.Truncate(streaming.MinimumWindowSizeDuration).IsEmpty() { + errs = append(errs, ValidationError{ + Err: ErrInvoiceCreateUBPLinePeriodIsEmpty, + }) + } + } else { + if i.RateCardDiscounts.Usage != nil { + errs = append(errs, fmt.Errorf("usage discounts are not allowed for flat price lines")) + } } - if i.DependsOnMeteredQuantity() && i.InvoiceAt.Before(i.Period.Truncate(streaming.MinimumWindowSizeDuration).End) { - errs = append(errs, fmt.Errorf("invoice at (%s) must be after period end (%s) for usage based line", i.InvoiceAt, i.Period.Truncate(streaming.MinimumWindowSizeDuration).End)) + if err := i.RateCardDiscounts.ValidateForPrice(i.UsageBased.Price); err != nil { + errs = append(errs, fmt.Errorf("rateCardDiscounts: %w", err)) } return errors.Join(errs...)