2 changes: 2 additions & 0 deletions README.md
@@ -155,6 +155,8 @@ Join our discord community via [this invite link](https://discord.gg/bxgXW8jJGh)
| <a name="input_key_name"></a> [key\_name](#input\_key\_name) | Key pair name | `string` | `null` | no |
| <a name="input_kms_key_arn"></a> [kms\_key\_arn](#input\_kms\_key\_arn) | Optional CMK Key ARN to be used for Parameter Store. This key must be in the current account. | `string` | `null` | no |
| <a name="input_lambda_architecture"></a> [lambda\_architecture](#input\_lambda\_architecture) | AWS Lambda architecture. Lambda functions using Graviton processors ('arm64') tend to have better price/performance than 'x86\_64' functions. | `string` | `"arm64"` | no |
| <a name="input_lambda_event_source_mapping_batch_size"></a> [lambda\_event\_source\_mapping\_batch\_size](#input\_lambda\_event\_source\_mapping\_batch\_size) | Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default of 10 events will be used. | `number` | `10` | no |
| <a name="input_lambda_event_source_mapping_maximum_batching_window_in_seconds"></a> [lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds](#input\_lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds) | Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch\_size is greater than 10. Defaults to 0. | `number` | `0` | no |
| <a name="input_lambda_principals"></a> [lambda\_principals](#input\_lambda\_principals) | (Optional) add extra principals to the role created for execution of the lambda, e.g. for local testing. | <pre>list(object({<br/> type = string<br/> identifiers = list(string)<br/> }))</pre> | `[]` | no |
| <a name="input_lambda_runtime"></a> [lambda\_runtime](#input\_lambda\_runtime) | AWS Lambda runtime. | `string` | `"nodejs22.x"` | no |
| <a name="input_lambda_s3_bucket"></a> [lambda\_s3\_bucket](#input\_lambda\_s3\_bucket) | S3 bucket from which to specify lambda functions. This is an alternative to providing local files directly. | `string` | `null` | no |
15 changes: 9 additions & 6 deletions lambdas/functions/control-plane/src/aws/runners.test.ts
@@ -360,9 +360,12 @@ describe('create runner with errors', () => {
});

it('test ScaleError with multiple error.', async () => {
createFleetMockWithErrors(['UnfulfillableCapacity', 'SomeError']);
createFleetMockWithErrors(['UnfulfillableCapacity', 'MaxSpotInstanceCountExceeded', 'NotMappedError']);

await expect(createRunner(createRunnerConfig(defaultRunnerConfig))).rejects.toBeInstanceOf(ScaleError);
await expect(createRunner(createRunnerConfig(defaultRunnerConfig))).rejects.toMatchObject({
name: 'ScaleError',
failedInstanceCount: 2,
});
expect(mockEC2Client).toHaveReceivedCommandWith(
CreateFleetCommand,
expectedCreateFleetRequest(defaultExpectedFleetRequestValues),
@@ -462,7 +465,7 @@ describe('create runner with errors fail over to OnDemand', () => {

expect(mockEC2Client).toHaveReceivedCommandTimes(CreateFleetCommand, 2);

// first call with spot failuer
// first call with spot failure
expect(mockEC2Client).toHaveReceivedNthCommandWith(1, CreateFleetCommand, {
...expectedCreateFleetRequest({
...defaultExpectedFleetRequestValues,
@@ -471,7 +474,7 @@ }),
}),
});

// second call with with OnDemand failback
// second call with OnDemand fallback
expect(mockEC2Client).toHaveReceivedNthCommandWith(2, CreateFleetCommand, {
...expectedCreateFleetRequest({
...defaultExpectedFleetRequestValues,
@@ -481,13 +484,13 @@ });
});
});

it('test InsufficientInstanceCapacity no failback.', async () => {
it('test InsufficientInstanceCapacity no fallback.', async () => {
await expect(
createRunner(createRunnerConfig({ ...defaultRunnerConfig, onDemandFailoverOnError: [] })),
).rejects.toBeInstanceOf(Error);
});

it('test InsufficientInstanceCapacity with mutlipte instances and fallback to on demand .', async () => {
it('test InsufficientInstanceCapacity with multiple instances and fallback to on demand.', async () => {
const instancesIds = ['i-123', 'i-456'];
createFleetMockWithWithOnDemandFallback(['InsufficientInstanceCapacity'], instancesIds);

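The updated assertion in this file checks `name: 'ScaleError'` and `failedInstanceCount: 2` on the rejected error, and the control-plane change below constructs the error with a count as its second argument. A minimal sketch of what the extended `ScaleError` could look like; the real class is defined elsewhere in the control-plane sources, so its exact shape here is an assumption:

```typescript
// Sketch only: assumed shape of ScaleError, inferred from the updated test and constructor call.
export class ScaleError extends Error {
  public failedInstanceCount: number;

  constructor(message: string, failedInstanceCount = 1) {
    super(message);
    this.name = 'ScaleError';
    this.failedInstanceCount = failedInstanceCount;
  }
}
```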
9 changes: 7 additions & 2 deletions lambdas/functions/control-plane/src/aws/runners.ts
@@ -166,6 +166,7 @@ async function processFleetResult(
{ data: fleet },
);
const errors = fleet.Errors?.flatMap((e) => e.ErrorCode || '') || [];
let failedCount = 0;

// Educated guess of errors that would make sense to retry based on the list
// https://docs.aws.amazon.com/AWSEC2/latest/APIReference/errors-overview.html
@@ -195,10 +196,10 @@ });
});
instances.push(...instancesOnDemand);
return instances;
} else if (errors.some((e) => scaleErrors.includes(e))) {
} else if ((failedCount = countScaleErrors(errors, scaleErrors)) > 0) {
logger.warn('Create fleet failed, ScaleError will be thrown to trigger retry for ephemeral runners.');
logger.debug('Create fleet failed.', { data: fleet.Errors });
throw new ScaleError('Failed to create instance, create fleet failed.');
throw new ScaleError('Failed to create instance, create fleet failed.', failedCount);
} else {
logger.warn('Create fleet failed, error not recognized as scaling error.', { data: fleet.Errors });
throw Error('Create fleet failed, no instance created.');
@@ -207,6 +208,10 @@ return instances;
return instances;
}

function countScaleErrors(errors: string[], scaleErrors: string[]): number {
return errors.reduce((acc, e) => (scaleErrors.includes(e) ? acc + 1 : acc), 0);
}

async function getAmiIdOverride(runnerParameters: Runners.RunnerInputParameters): Promise<string | undefined> {
if (!runnerParameters.amiIdSsmParameterName) {
return undefined;
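A quick worked example of how the new count flows into the error, using the error codes from the updated test above; the `scaleErrors` list here is an abbreviated stand-in for the one defined in `runners.ts`:

```typescript
// Error codes returned by CreateFleet in the updated test case.
const errors = ['UnfulfillableCapacity', 'MaxSpotInstanceCountExceeded', 'NotMappedError'];
// Abbreviated stand-in for the scaleErrors list in runners.ts.
const scaleErrors = ['UnfulfillableCapacity', 'MaxSpotInstanceCountExceeded'];

// Same reduction as countScaleErrors above: two of the three codes count as scaling errors.
const failedCount = errors.reduce((acc, e) => (scaleErrors.includes(e) ? acc + 1 : acc), 0);
// failedCount === 2, so processFleetResult throws new ScaleError('...', 2),
// which the test observes as failedInstanceCount: 2.
```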
183 changes: 155 additions & 28 deletions lambdas/functions/control-plane/src/lambda.test.ts
@@ -28,11 +28,11 @@ const sqsRecord: SQSRecord = {
},
awsRegion: '',
body: JSON.stringify(body),
eventSource: 'aws:SQS',
eventSource: 'aws:sqs',
eventSourceARN: '',
md5OfBody: '',
messageAttributes: {},
messageId: '',
messageId: 'abcd1234',
receiptHandle: '',
};

@@ -70,19 +70,33 @@ vi.mock('@aws-github-runner/aws-powertools-util');
vi.mock('@aws-github-runner/aws-ssm-util');

describe('Test scale up lambda wrapper.', () => {
it('Do not handle multiple record sets.', async () => {
await testInvalidRecords([sqsRecord, sqsRecord]);
it('Do not handle empty record sets.', async () => {
const sqsEventMultipleRecords: SQSEvent = {
Records: [],
};

await expect(scaleUpHandler(sqsEventMultipleRecords, context)).resolves.not.toThrow();
});

it('Do not handle empty record sets.', async () => {
await testInvalidRecords([]);
it('Ignores non-sqs event sources.', async () => {
const record = {
...sqsRecord,
eventSource: 'aws:non-sqs',
};

const sqsEventMultipleRecordsNonSQS: SQSEvent = {
Records: [record],
};

await expect(scaleUpHandler(sqsEventMultipleRecordsNonSQS, context)).resolves.not.toThrow();
expect(scaleUp).toHaveBeenCalledWith([]);
});

it('Scale without error should resolve.', async () => {
const mock = vi.fn(scaleUp);
mock.mockImplementation(() => {
return new Promise((resolve) => {
resolve();
resolve([]);
});
});
await expect(scaleUpHandler(sqsEvent, context)).resolves.not.toThrow();
@@ -95,37 +109,150 @@ describe('Test scale up lambda wrapper.', () => {
await expect(scaleUpHandler(sqsEvent, context)).resolves.not.toThrow();
});

it('Scale should be rejected', async () => {
it('Scale should create a batch failure message', async () => {
const error = new ScaleError('Scale should be rejected');
const mock = vi.fn() as MockedFunction<typeof scaleUp>;
mock.mockImplementation(() => {
return Promise.reject(error);
});
vi.mocked(scaleUp).mockImplementation(mock);
await expect(scaleUpHandler(sqsEvent, context)).rejects.toThrow(error);
await expect(scaleUpHandler(sqsEvent, context)).resolves.toEqual({
batchItemFailures: [{ itemIdentifier: sqsRecord.messageId }],
});
});
});

async function testInvalidRecords(sqsRecords: SQSRecord[]) {
const mock = vi.fn(scaleUp);
const logWarnSpy = vi.spyOn(logger, 'warn');
mock.mockImplementation(() => {
return new Promise((resolve) => {
resolve();
describe('Batch processing', () => {
beforeEach(() => {
vi.clearAllMocks();
});

const createMultipleRecords = (count: number, eventSource = 'aws:sqs'): SQSRecord[] => {
return Array.from({ length: count }, (_, i) => ({
...sqsRecord,
eventSource,
messageId: `message-${i}`,
body: JSON.stringify({
...body,
id: i + 1,
}),
}));
};

it('Should handle multiple SQS records in a single invocation', async () => {
const records = createMultipleRecords(3);
const multiRecordEvent: SQSEvent = { Records: records };

const mock = vi.fn(scaleUp);
mock.mockImplementation(() => Promise.resolve([]));
vi.mocked(scaleUp).mockImplementation(mock);

await expect(scaleUpHandler(multiRecordEvent, context)).resolves.not.toThrow();
expect(scaleUp).toHaveBeenCalledWith(
expect.arrayContaining([
expect.objectContaining({ messageId: 'message-0' }),
expect.objectContaining({ messageId: 'message-1' }),
expect.objectContaining({ messageId: 'message-2' }),
]),
);
});

it('Should return batch item failures for rejected messages', async () => {
const records = createMultipleRecords(3);
const multiRecordEvent: SQSEvent = { Records: records };

const mock = vi.fn(scaleUp);
mock.mockImplementation(() => Promise.resolve(['message-1', 'message-2']));
vi.mocked(scaleUp).mockImplementation(mock);

const result = await scaleUpHandler(multiRecordEvent, context);
expect(result).toEqual({
batchItemFailures: [{ itemIdentifier: 'message-1' }, { itemIdentifier: 'message-2' }],
});
});

it('Should filter out non-SQS event sources', async () => {
const sqsRecords = createMultipleRecords(2, 'aws:sqs');
const nonSqsRecords = createMultipleRecords(1, 'aws:sns');
const mixedEvent: SQSEvent = {
Records: [...sqsRecords, ...nonSqsRecords],
};

const mock = vi.fn(scaleUp);
mock.mockImplementation(() => Promise.resolve([]));
vi.mocked(scaleUp).mockImplementation(mock);

await scaleUpHandler(mixedEvent, context);
expect(scaleUp).toHaveBeenCalledWith(
expect.arrayContaining([
expect.objectContaining({ messageId: 'message-0' }),
expect.objectContaining({ messageId: 'message-1' }),
]),
);
expect(scaleUp).not.toHaveBeenCalledWith(
expect.arrayContaining([expect.objectContaining({ messageId: 'message-2' })]),
);
});

it('Should sort messages by retry count', async () => {
const records = [
{
...sqsRecord,
messageId: 'high-retry',
body: JSON.stringify({ ...body, retryCounter: 5 }),
},
{
...sqsRecord,
messageId: 'low-retry',
body: JSON.stringify({ ...body, retryCounter: 1 }),
},
{
...sqsRecord,
messageId: 'no-retry',
body: JSON.stringify({ ...body }),
},
];
const multiRecordEvent: SQSEvent = { Records: records };

const mock = vi.fn(scaleUp);
mock.mockImplementation((messages) => {
// Verify messages are sorted by retry count (ascending)
expect(messages[0].messageId).toBe('no-retry');
expect(messages[1].messageId).toBe('low-retry');
expect(messages[2].messageId).toBe('high-retry');
return Promise.resolve([]);
});
vi.mocked(scaleUp).mockImplementation(mock);

await scaleUpHandler(multiRecordEvent, context);
});

it('Should return all failed messages when scaleUp throws non-ScaleError', async () => {
const records = createMultipleRecords(2);
const multiRecordEvent: SQSEvent = { Records: records };

const mock = vi.fn(scaleUp);
mock.mockImplementation(() => Promise.reject(new Error('Generic error')));
vi.mocked(scaleUp).mockImplementation(mock);

const result = await scaleUpHandler(multiRecordEvent, context);
expect(result).toEqual({ batchItemFailures: [] });
});

it('Should throw when scaleUp throws ScaleError', async () => {
const records = createMultipleRecords(2);
const multiRecordEvent: SQSEvent = { Records: records };

const error = new ScaleError('Critical scaling error', 2);
const mock = vi.fn(scaleUp);
mock.mockImplementation(() => Promise.reject(error));
vi.mocked(scaleUp).mockImplementation(mock);

await expect(scaleUpHandler(multiRecordEvent, context)).resolves.toEqual({
batchItemFailures: [{ itemIdentifier: 'message-0' }, { itemIdentifier: 'message-1' }],
});
});
});
const sqsEventMultipleRecords: SQSEvent = {
Records: sqsRecords,
};

await expect(scaleUpHandler(sqsEventMultipleRecords, context)).resolves.not.toThrow();

expect(logWarnSpy).toHaveBeenCalledWith(
expect.stringContaining(
'Event ignored, only one record at the time can be handled, ensure the lambda batch size is set to 1.',
),
);
}
});
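
Taken together, the new tests pin down the batch behaviour expected of the scale-up wrapper: non-SQS records are filtered out, messages are sorted by `retryCounter` ascending, message IDs returned by `scaleUp` become batch item failures, a `ScaleError` marks the whole batch as failed, and any other error drops the batch without retry. A rough sketch of a handler that would satisfy these tests; the names, the `scaleUp` signature, and the JSON body parsing are assumptions inferred from the tests, not the actual implementation:

```typescript
import { Context, SQSBatchResponse, SQSEvent, SQSRecord } from 'aws-lambda';

// Stand-ins for the real control-plane imports (assumed signatures):
// scaleUp receives the filtered records and resolves with the messageIds it could not process.
declare function scaleUp(messages: SQSRecord[]): Promise<string[]>;
declare class ScaleError extends Error {}

export async function scaleUpHandlerSketch(event: SQSEvent, _context: Context): Promise<SQSBatchResponse> {
  // Keep only genuine SQS records and process messages with low retry counts first.
  const messages = event.Records.filter((r) => r.eventSource === 'aws:sqs').sort(
    (a, b) => (JSON.parse(a.body).retryCounter ?? 0) - (JSON.parse(b.body).retryCounter ?? 0),
  );

  try {
    const failedIds = await scaleUp(messages);
    // Only the listed messages are retried by SQS.
    return { batchItemFailures: failedIds.map((itemIdentifier) => ({ itemIdentifier })) };
  } catch (e) {
    if (e instanceof ScaleError) {
      // Scaling errors: retry the whole batch so ephemeral runners get another attempt.
      return { batchItemFailures: messages.map((m) => ({ itemIdentifier: m.messageId })) };
    }
    // Anything else: drop the batch instead of retrying.
    return { batchItemFailures: [] };
  }
}
```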

describe('Test scale down lambda wrapper.', () => {
it('Scaling down no error.', async () => {