Skip to content

Commit 45625e5

Browse files
authored
Make limit int32 (#1609)
1 parent b8143f6 commit 45625e5

File tree

6 files changed

+28
-24
lines changed

6 files changed

+28
-24
lines changed

src/pkg/cli/client/byoc/aws/byoc.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,7 @@ func (b *ByocAws) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (cli
686686
etag = "" // no need to filter by etag
687687
}
688688
} else {
689-
tailStream, err = b.driver.QueryTaskID(ctx, etag, time.Time{}, time.Now(), int(req.Limit))
689+
tailStream, err = b.driver.QueryTaskID(ctx, etag, time.Time{}, time.Now(), req.Limit)
690690
if err == nil {
691691
b.cdTaskArn, err = b.driver.GetTaskArn(etag)
692692
etag = "" // no need to filter by etag
@@ -713,12 +713,11 @@ func (b *ByocAws) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (cli
713713
lgis...,
714714
)
715715
} else {
716-
limit := int(req.Limit)
717716
evtsChan, errsChan := ecs.QueryLogGroups(
718717
ctx,
719718
start,
720719
end,
721-
limit,
720+
req.Limit,
722721
lgis...,
723722
)
724723
if evtsChan == nil {

src/pkg/cli/client/byoc/gcp/byoc.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ func (b *ByocGcp) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (cli
561561
if req.Follow {
562562
subscribeStream.StartFollow(since)
563563
} else {
564-
subscribeStream.Start(int(req.Limit))
564+
subscribeStream.Start(req.Limit)
565565
}
566566

567567
var cancel context.CancelCauseFunc
@@ -624,7 +624,7 @@ func (b *ByocGcp) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (cli
624624
if req.Follow {
625625
logStream.StartFollow(startTime)
626626
} else {
627-
logStream.Start(int(req.Limit))
627+
logStream.Start(req.Limit)
628628
}
629629
return logStream, nil
630630
}

src/pkg/cli/client/byoc/gcp/stream.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ func (s *ServerStream[T]) StartFollow(start time.Time) {
125125
}()
126126
}
127127

128-
func (s *ServerStream[T]) Start(limit int) {
128+
func (s *ServerStream[T]) Start(limit int32) {
129129
query := s.query.GetQuery()
130130
term.Debugf("Query logs with query: \n%v", query)
131131
go func() {
@@ -146,7 +146,7 @@ func (s *ServerStream[T]) queryHead(query string) {
146146
}
147147
}
148148

149-
func (s *ServerStream[T]) queryTail(query string, limit int) {
149+
func (s *ServerStream[T]) queryTail(query string, limit int32) {
150150
lister, err := s.gcp.ListLogEntries(s.ctx, query, gcp.OrderDescending)
151151
if err != nil {
152152
s.errCh <- err
@@ -171,7 +171,7 @@ func (s *ServerStream[T]) queryTail(query string, limit int) {
171171
}
172172
}
173173

174-
func (s *ServerStream[T]) listToBuffer(lister *gcp.Lister, limit int) ([]*T, error) {
174+
func (s *ServerStream[T]) listToBuffer(lister *gcp.Lister, limit int32) ([]*T, error) {
175175
received := 0
176176
buffer := make([]*T, 0, limit)
177177
for range limit {
@@ -191,7 +191,7 @@ func (s *ServerStream[T]) listToBuffer(lister *gcp.Lister, limit int) ([]*T, err
191191
return buffer, nil
192192
}
193193

194-
func (s *ServerStream[T]) listToChannel(lister *gcp.Lister, limit int) error {
194+
func (s *ServerStream[T]) listToChannel(lister *gcp.Lister, limit int32) error {
195195
received := 0
196196
for {
197197
entry, err := lister.Next()
@@ -206,7 +206,7 @@ func (s *ServerStream[T]) listToChannel(lister *gcp.Lister, limit int) error {
206206
s.respCh <- resp
207207
}
208208
received += len(resps)
209-
if limit > 0 && received >= limit {
209+
if limit > 0 && received >= int(limit) {
210210
return io.EOF
211211
}
212212
}

src/pkg/clouds/aws/ecs/logs.go

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,16 +124,15 @@ func TailLogGroup(ctx context.Context, input LogGroupInput) (LiveTailStream, err
124124
return slto.GetStream(), nil
125125
}
126126

127-
func QueryLogGroups(ctx context.Context, start, end time.Time, limit int, logGroups ...LogGroupInput) (<-chan LogEvent, <-chan error) {
128-
// Gather logs from the CD task, kaniko, ECS events, and all services
127+
func QueryLogGroups(ctx context.Context, start, end time.Time, limit int32, logGroups ...LogGroupInput) (<-chan LogEvent, <-chan error) {
129128
var evtsChan chan LogEvent
130129
var errChan chan error
131130
for _, lgi := range logGroups {
132131
lgEvtChan := make(chan LogEvent)
133132
// Start a go routine for each log group
134133
go func(lgi LogGroupInput) {
135134
defer close(lgEvtChan)
136-
if err := QueryLogGroup(ctx, lgi, start, end, func(logEvents []LogEvent) error {
135+
if err := QueryLogGroup(ctx, lgi, start, end, limit, func(logEvents []LogEvent) error {
137136
for _, event := range logEvents {
138137
lgEvtChan <- event
139138
}
@@ -145,22 +144,22 @@ func QueryLogGroups(ctx context.Context, start, end time.Time, limit int, logGro
145144
evtsChan = mergeLogEventChan(evtsChan, lgEvtChan) // Merge sort the log events based on timestamp
146145
// take the last n events only
147146
if limit > 0 {
148-
evtsChan = takeLastN(evtsChan, limit)
147+
evtsChan = takeLastN(evtsChan, int(limit))
149148
}
150149
}
151150
return evtsChan, errChan
152151
}
153152

154-
func QueryLogGroup(ctx context.Context, input LogGroupInput, start, end time.Time, cb func([]LogEvent) error) error {
153+
func QueryLogGroup(ctx context.Context, input LogGroupInput, start, end time.Time, limit int32, cb func([]LogEvent) error) error {
155154
region := region.FromArn(input.LogGroupARN)
156155
cw, err := newCloudWatchLogsClient(ctx, region)
157156
if err != nil {
158157
return err
159158
}
160-
return filterLogEvents(ctx, cw, input, start, end, cb)
159+
return filterLogEvents(ctx, cw, input, start, end, limit, cb)
161160
}
162161

163-
func filterLogEvents(ctx context.Context, cw *cloudwatchlogs.Client, lgi LogGroupInput, start, end time.Time, cb func([]LogEvent) error) error {
162+
func filterLogEvents(ctx context.Context, cw *cloudwatchlogs.Client, lgi LogGroupInput, start, end time.Time, limit int32, cb func([]LogEvent) error) error {
164163
var pattern *string
165164
if lgi.LogEventFilterPattern != "" {
166165
pattern = &lgi.LogEventFilterPattern
@@ -189,6 +188,11 @@ func filterLogEvents(ctx context.Context, cw *cloudwatchlogs.Client, lgi LogGrou
189188
params.LogStreamNamePrefix = &lgi.LogStreamNamePrefix
190189
}
191190
for {
191+
if limit > 0 {
192+
// Specifying the limit parameter only guarantees that a single page doesn't return more log events than the
193+
// specified limit, but it might return fewer events than the limit. This is the expected API behavior.
194+
params.Limit = ptr.Int32(limit)
195+
}
192196
fleo, err := cw.FilterLogEvents(ctx, params)
193197
if err != nil {
194198
return err
@@ -209,6 +213,10 @@ func filterLogEvents(ctx context.Context, cw *cloudwatchlogs.Client, lgi LogGrou
209213
if fleo.NextToken == nil {
210214
return nil
211215
}
216+
if limit > 0 && len(events) >= int(limit) {
217+
return nil
218+
}
219+
limit -= int32(len(events)) // #nosec G115 - always safe because len(events) <= limit
212220
params.NextToken = fleo.NextToken
213221
}
214222
}

src/pkg/clouds/aws/ecs/stream.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func QueryAndTailLogGroup(ctx context.Context, lgi LogGroupInput, start, end tim
5151
end = time.Now()
5252
}
5353
// Query the logs between the start time and now; TODO: could use a single CloudWatch client for all queries in same region
54-
if err := QueryLogGroup(ctx, lgi, start, end, func(events []LogEvent) error {
54+
if err := QueryLogGroup(ctx, lgi, start, end, 0, func(events []LogEvent) error {
5555
es.ch <- &types.StartLiveTailResponseStreamMemberSessionUpdate{
5656
Value: types.LiveTailSessionUpdate{SessionResults: events},
5757
}

src/pkg/clouds/aws/ecs/tail.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func (a *AwsEcs) GetTaskArn(taskID string) (TaskArn, error) {
5757
return &taskArn, nil
5858
}
5959

60-
func (a *AwsEcs) QueryTaskID(ctx context.Context, taskID string, start, end time.Time, limit int) (EventStream[types.StartLiveTailResponseStream], error) {
60+
func (a *AwsEcs) QueryTaskID(ctx context.Context, taskID string, start, end time.Time, limit int32) (EventStream[types.StartLiveTailResponseStream], error) {
6161
if taskID == "" {
6262
return nil, errors.New("taskID is empty")
6363
}
@@ -68,7 +68,8 @@ func (a *AwsEcs) QueryTaskID(ctx context.Context, taskID string, start, end time
6868
}
6969

7070
lgi := LogGroupInput{LogGroupARN: a.LogGroupARN, LogStreamNames: []string{GetCDLogStreamForTaskID(taskID)}}
71-
if err := QueryLogGroup(ctx, lgi, start, end, func(events []LogEvent) error {
71+
// Note: this function only returns once the query is complete, so returning an event stream is somewhat misleading
72+
if err := QueryLogGroup(ctx, lgi, start, end, limit, func(events []LogEvent) error {
7273
es.ch <- &types.StartLiveTailResponseStreamMemberSessionUpdate{
7374
Value: types.LiveTailSessionUpdate{SessionResults: events},
7475
}
@@ -77,10 +78,6 @@ func (a *AwsEcs) QueryTaskID(ctx context.Context, taskID string, start, end time
7778
es.err = err
7879
}
7980

80-
if limit > 0 {
81-
es.ch = takeLastN(es.ch, limit)
82-
}
83-
8481
return es, nil
8582
}
8683

0 commit comments

Comments
 (0)