diff --git a/contract-tests/README.md b/contract-tests/README.md
new file mode 100644
index 00000000..f7746d27
--- /dev/null
+++ b/contract-tests/README.md
@@ -0,0 +1,55 @@
+## Introduction
+
+This directory contains contract tests that exist to prevent regressions. They cover:
+
+* [OpenTelemetry semantic conventions](https://github.com/open-telemetry/semantic-conventions/).
+* Application Signals-specific attributes.
+
+## How does it work?
+
+The tests present here rely on the auto-instrumentation of a sample application, which will send telemetry signals to a mock collector. The tests will use the data collected by the mock collector to perform assertions and validate that the contracts are being respected.
+
+## Types of tested frameworks
+
+The frameworks and libraries that are tested in the contract tests should fall into the following categories (more can be added on demand):
+
+* http-servers - applications meant to test HTTP servers (e.g. the `http` module in Node.js).
+* aws-sdk - applications meant to test the AWS SDK (e.g. AWS SDK for JavaScript v3).
+* database-clients - applications meant to test database clients (e.g. mysql2, Mongoose, MongoDB).
+
+When testing a framework, we will create a sample application. The sample applications are stored following this convention: `contract-tests/images/applications/`.
+
+## Adding tests for a new library or framework
+
+The steps to add a new test for a library or framework are:
+
+* Create a sample application.
+  * The sample application should be created in `contract-tests/images/applications/`.
+  * Implement a Node.js application and create a `Dockerfile` to containerize the application.
+* Add a test class for the sample application.
+  * The test class should be created in `contract-tests/tests/amazon/`.
+  * The test class should extend `contract_test_base.py`.
+
+## How to run the tests locally
+
+Pre-requirements:
+
+* Have `docker` installed and running - verify by running the `docker` command.
+
+Steps:
+
+* From `aws-otel-js-instrumentation` dir, execute:
+
+```sh
+# create a virtual environment in python for the tests
+python3 -m venv venv
+source venv/bin/activate
+# build the instrumentation SDK
+./scripts/build_and_install_distro.sh
+# build the relevant images for sample app and build the contract tests
+./scripts/set-up-contract-tests.sh
+# run all the tests
+pytest contract-tests/tests
+# exit the virtual python environment
+deactivate
+```
diff --git a/contract-tests/images/applications/aws-sdk/Dockerfile b/contract-tests/images/applications/aws-sdk/Dockerfile
new file mode 100644
index 00000000..c4f3c62e
--- /dev/null
+++ b/contract-tests/images/applications/aws-sdk/Dockerfile
@@ -0,0 +1,22 @@
+# Use an official Node.js runtime as the base image
+FROM node:20-alpine
+
+# Set the working directory inside the container
+WORKDIR /aws-sdk
+
+# Declare DISTRO before its first use; otherwise $DISTRO expands to an empty string in COPY
+ARG DISTRO
+
+# Copy the relevant files
+COPY ./dist/$DISTRO /aws-sdk
+COPY ./contract-tests/images/applications/aws-sdk /aws-sdk
+
+# Install dependencies
+RUN npm install
+RUN npm install ./${DISTRO}
+
+# Expose the port the app runs on
+EXPOSE 8080
+
+# Run the app with nodejs auto instrumentation
+CMD ["node", "--require", "@aws/aws-distro-opentelemetry-node-autoinstrumentation/register", "server.js"]
diff --git a/contract-tests/images/applications/aws-sdk/package.json b/contract-tests/images/applications/aws-sdk/package.json
new file mode 100644
index 00000000..25ade2b9
--- /dev/null
+++ b/contract-tests/images/applications/aws-sdk/package.json
@@ -0,0 +1,20 @@
+{
+  "name": "aws-sdk-forwarder",
+  "version": "1.0.0",
+  "main": "server.js",
+  "scripts": {
+    "test": "echo \"Error: no test specified\" && exit 1",
+    "start": "node server.js"
+  },
+  "author": "",
+  "license": "ISC",
+  "description": "",
+  "dependencies": {
+    "@aws-sdk/client-dynamodb": "^3.658.1",
+    "@aws-sdk/client-kinesis": "^3.658.1",
+    "@aws-sdk/client-s3": "^3.658.1",
+    "@aws-sdk/client-sqs": "^3.658.1",
+ 
"@smithy/node-http-handler": "^3.2.3", + "node-fetch": "^2.7.0" + } +} diff --git a/contract-tests/images/applications/aws-sdk/server.js b/contract-tests/images/applications/aws-sdk/server.js new file mode 100644 index 00000000..9f9ce9e7 --- /dev/null +++ b/contract-tests/images/applications/aws-sdk/server.js @@ -0,0 +1,493 @@ +// server.js +const http = require('http'); +const url = require('url'); +const fs = require('fs'); +const os = require('os'); +const ospath = require('path'); +const { NodeHttpHandler } =require('@smithy/node-http-handler'); + +const { S3Client, CreateBucketCommand, PutObjectCommand, GetObjectCommand } = require('@aws-sdk/client-s3'); +const { DynamoDBClient, CreateTableCommand, PutItemCommand } = require('@aws-sdk/client-dynamodb'); +const { SQSClient, CreateQueueCommand, SendMessageCommand, ReceiveMessageCommand } = require('@aws-sdk/client-sqs'); +const { KinesisClient, CreateStreamCommand, PutRecordCommand } = require('@aws-sdk/client-kinesis'); +const fetch = require('node-fetch'); + +const _PORT = 8080; +const _ERROR = 'error'; +const _FAULT = 'fault'; + +const _AWS_SDK_S3_ENDPOINT = process.env.AWS_SDK_S3_ENDPOINT; +const _AWS_SDK_ENDPOINT = process.env.AWS_SDK_ENDPOINT; +const _AWS_REGION = process.env.AWS_REGION; +const _FAULT_ENDPOINT = 'http://fault.test:8080'; + +process.env.AWS_ACCESS_KEY_ID = process.env.AWS_ACCESS_KEY_ID || 'testcontainers-localstack'; +process.env.AWS_SECRET_ACCESS_KEY = process.env.AWS_SECRET_ACCESS_KEY || 'testcontainers-localstack'; + +const noRetryConfig = { + maxAttempts: 0, + requestHandler: { + metadata: { handlerProtocol: 'http/1.1' }, + connectionTimeout: 3000, + socketTimeout: 3000, + }, +}; + +let statusCodeForFault = 200; + +async function prepareAwsServer() { + try { + // Initialize AWS SDK clients + const s3Client = new S3Client({ + endpoint: _AWS_SDK_S3_ENDPOINT, + region: _AWS_REGION, + forcePathStyle: true, + }); + + const ddbClient = new DynamoDBClient({ + endpoint: _AWS_SDK_ENDPOINT, + 
region: _AWS_REGION, + }); + + const sqsClient = new SQSClient({ + endpoint: _AWS_SDK_ENDPOINT, + region: _AWS_REGION, + }); + + const kinesisClient = new KinesisClient({ + endpoint: _AWS_SDK_ENDPOINT, + region: _AWS_REGION, + }); + + // Set up S3 + await s3Client.send( + new CreateBucketCommand({ + Bucket: 'test-put-object-bucket-name', + CreateBucketConfiguration: { LocationConstraint: _AWS_REGION }, + }) + ); + + await s3Client.send( + new CreateBucketCommand({ + Bucket: 'test-get-object-bucket-name', + CreateBucketConfiguration: { LocationConstraint: _AWS_REGION }, + }) + ); + + // Upload a file to S3 + const tempFileName = ospath.join(os.tmpdir(), 'tempfile'); + fs.writeFileSync(tempFileName, 'This is temp file for S3 upload'); + const fileStream = fs.createReadStream(tempFileName); + await s3Client.send( + new PutObjectCommand({ + Bucket: 'test-get-object-bucket-name', + Key: 'test_object', + Body: fileStream, + }) + ); + fs.unlinkSync(tempFileName); + + // Set up DynamoDB + await ddbClient.send( + new CreateTableCommand({ + TableName: 'put_test_table', + KeySchema: [{ AttributeName: 'id', KeyType: 'HASH' }], + AttributeDefinitions: [{ AttributeName: 'id', AttributeType: 'S' }], + BillingMode: 'PAY_PER_REQUEST', + }) + ); + + // Set up SQS + await sqsClient.send( + new CreateQueueCommand({ + QueueName: 'test_put_get_queue', + }) + ); + + // Set up Kinesis + await kinesisClient.send( + new CreateStreamCommand({ + StreamName: 'test_stream', + ShardCount: 1, + }) + ); + } catch (error) { + console.error('Unexpected exception occurred', error); + } +} + +const server = http.createServer(async (req, res) => { + const parsedUrl = url.parse(req.url); + const pathName = parsedUrl.pathname; + + if (req.method === 'GET') { + await handleGetRequest(req, res, pathName); + } else if (req.method === 'POST') { + await handlePostRequest(req, res, pathName); + } else if (req.method === 'PUT') { + await handlePutRequest(req, res, pathName); + } else { + res.writeHead(405); + 
res.end(); + } +}); + +async function handleGetRequest(req, res, path) { + if (path.includes('s3')) { + await handleS3Request(req, res, path); + } else if (path.includes('ddb')) { + await handleDdbRequest(req, res, path); + } else if (path.includes('sqs')) { + await handleSqsRequest(req, res, path); + } else if (path.includes('kinesis')) { + await handleKinesisRequest(req, res, path); + } else { + res.writeHead(404); + res.end(); + } +} + +// this can be served as the fake AWS service to generate fault responses +async function handlePostRequest(req, res, path) { + res.writeHead(statusCodeForFault); + res.end(); +} + +// this can be served as the fake AWS service to generate fault responses +async function handlePutRequest(req, res, path) { + res.writeHead(statusCodeForFault); + res.end(); +} + +async function handleS3Request(req, res, path) { + const s3Client = new S3Client({ + endpoint: _AWS_SDK_S3_ENDPOINT, + region: _AWS_REGION, + forcePathStyle: true, + }); + + if (path.includes(_ERROR)) { + res.statusCode = 400; + try { + // trigger error case with an invalid bucket name + await s3Client.send( + new CreateBucketCommand({ + Bucket: '-', + }) + ); + } catch (err) { + console.log('Expected exception occurred', err); + } + res.end(); + } else if (path.includes(_FAULT)) { + res.statusCode = 500; + // save the status code so that the current server will response correctly + // when the faultS3Client connect to it + statusCodeForFault = 500; + try { + const faultS3Client = new S3Client({ + endpoint: _FAULT_ENDPOINT, + region: _AWS_REGION, + forcePathStyle: true, + maxAttempts: 0, + requestHandler: { + metadata: { handlerProtocol: 'http/1.1' }, + connectionTimeout: 3000, + socketTimeout: 3000, + }, + }); + await faultS3Client.send( + new CreateBucketCommand({ + Bucket: 'valid-bucket-name', + }) + ); + } catch (err) { + console.log('Expected exception occurred', err); + } + res.end(); + } else if (path.includes('createbucket/create-bucket')) { + try { + await 
s3Client.send( + new CreateBucketCommand({ + Bucket: 'test-bucket-name', + CreateBucketConfiguration: { LocationConstraint: _AWS_REGION }, + }) + ); + res.statusCode = 200; + } catch (err) { + console.log('Error creating bucket', err); + res.statusCode = 500; + } + res.end(); + } else if (path.includes('createobject/put-object/some-object')) { + try { + const tempFileName = ospath.join(os.tmpdir(), 'tempfile'); + fs.writeFileSync(tempFileName, 'This is temp file for S3 upload'); + const fileStream = fs.createReadStream(tempFileName); + await s3Client.send( + new PutObjectCommand({ + Bucket: 'test-put-object-bucket-name', + Key: 'test_object', + Body: fileStream, + }) + ); + fs.unlinkSync(tempFileName); + res.statusCode = 200; + } catch (err) { + console.log('Error uploading file', err); + res.statusCode = 500; + } + res.end(); + } else if (path.includes('getobject/get-object/some-object')) { + try { + const data = await s3Client.send( + new GetObjectCommand({ + Bucket: 'test-get-object-bucket-name', + Key: 'test_object', + }) + ); + res.statusCode = 200; + res.end(); + } catch (err) { + console.log('Error getting object', err); + res.statusCode = 500; + res.end(); + } + } else { + res.statusCode = 404; + res.end(); + } +} + +async function handleDdbRequest(req, res, path) { + const ddbClient = new DynamoDBClient({ + endpoint: _AWS_SDK_ENDPOINT, + region: _AWS_REGION, + }); + + if (path.includes(_ERROR)) { + res.statusCode = 400; + try { + const item = { id: { S: '1' } }; + await ddbClient.send( + new PutItemCommand({ + TableName: 'invalid_table', + Item: item, + }) + ); + } catch (err) { + console.log('Expected exception occurred', err); + } + res.end(); + } else if (path.includes(_FAULT)) { + res.statusCode = 500; + statusCodeForFault = 500; + try { + const faultDdbClient = new DynamoDBClient({ + endpoint: _FAULT_ENDPOINT, + region: _AWS_REGION, + ...noRetryConfig, + }); + const item = { id: { S: '1' } }; + await faultDdbClient.send( + new PutItemCommand({ + 
TableName: 'invalid_table', + Item: item, + }) + ); + } catch (err) { + console.log('Expected exception occurred', err); + } + res.end(); + } else if (path.includes('createtable/some-table')) { + try { + await ddbClient.send( + new CreateTableCommand({ + TableName: 'test_table', + KeySchema: [{ AttributeName: 'id', KeyType: 'HASH' }], + AttributeDefinitions: [{ AttributeName: 'id', AttributeType: 'S' }], + BillingMode: 'PAY_PER_REQUEST', + }) + ); + res.statusCode = 200; + } catch (err) { + console.log('Error creating table', err); + res.statusCode = 500; + } + res.end(); + } else if (path.includes('putitem/putitem-table/key')) { + try { + const item = { id: { S: '1' } }; + await ddbClient.send( + new PutItemCommand({ + TableName: 'put_test_table', + Item: item, + }) + ); + res.statusCode = 200; + } catch (err) { + console.log('Error putting item', err); + res.statusCode = 500; + } + res.end(); + } else { + res.statusCode = 404; + res.end(); + } +} + +async function handleSqsRequest(req, res, path) { + const sqsClient = new SQSClient({ + endpoint: _AWS_SDK_ENDPOINT, + region: _AWS_REGION, + }); + + if (path.includes(_ERROR)) { + res.statusCode = 400; + try { + await sqsClient.send( + new SendMessageCommand({ + QueueUrl: 'http://error.test:8080/000000000000/sqserror', + MessageBody: _ERROR, + }) + ); + } catch (err) { + console.log('Expected exception occurred', err); + } + res.end(); + } else if (path.includes(_FAULT)) { + res.statusCode = 500; + statusCodeForFault = 500; + try { + const faultSqsClient = new SQSClient({ + endpoint: _FAULT_ENDPOINT, + region: _AWS_REGION, + ...noRetryConfig, + }); + await faultSqsClient.send( + new CreateQueueCommand({ + QueueName: 'invalid_test', + }) + ); + } catch (err) { + console.log('Expected exception occurred', err); + } + res.end(); + } else if (path.includes('createqueue/some-queue')) { + try { + await sqsClient.send( + new CreateQueueCommand({ + QueueName: 'test_queue', + }) + ); + res.statusCode = 200; + } catch (err) { 
+ console.log('Error creating queue', err); + res.statusCode = 500; + } + res.end(); + } else if (path.includes('publishqueue/some-queue')) { + try { + await sqsClient.send( + new SendMessageCommand({ + QueueUrl: 'http://localstack:4566/000000000000/test_put_get_queue', + MessageBody: 'test_message', + }) + ); + res.statusCode = 200; + } catch (err) { + console.log('Error sending message', err); + res.statusCode = 500; + } + res.end(); + } else if (path.includes('consumequeue/some-queue')) { + try { + await sqsClient.send( + new ReceiveMessageCommand({ + QueueUrl: 'http://localstack:4566/000000000000/test_put_get_queue', + MaxNumberOfMessages: 1, + }) + ); + res.statusCode = 200; + } catch (err) { + console.log('Error receiving message', err); + res.statusCode = 500; + } + res.end(); + } else { + res.statusCode = 404; + res.end(); + } +} + +async function handleKinesisRequest(req, res, path) { + const kinesisClient = new KinesisClient({ + endpoint: _AWS_SDK_ENDPOINT, + region: _AWS_REGION, + }); + + if (path.includes(_ERROR)) { + res.statusCode = 400; + try { + await kinesisClient.send( + new PutRecordCommand({ + StreamName: 'invalid_stream', + Data: Buffer.from('test'), + PartitionKey: 'partition_key', + }) + ); + } catch (err) { + console.log('Expected exception occurred', err); + } + res.end(); + } else if (path.includes(_FAULT)) { + res.statusCode = 500; + statusCodeForFault = 500; + try { + const faultKinesisClient = new KinesisClient({ + endpoint: _FAULT_ENDPOINT, + region: _AWS_REGION, + requestHandler: new NodeHttpHandler({ + connectionTimeout: 3000, + socketTimeout: 3000, + }), + maxAttempts: 0, + }); + await faultKinesisClient.send( + new PutRecordCommand({ + StreamName: 'test_stream', + Data: Buffer.from('test'), + PartitionKey: 'partition_key', + }) + ); + } catch (err) { + console.log('Expected exception occurred', err); + } + res.end(); + } else if (path.includes('putrecord/my-stream')) { + try { + await kinesisClient.send( + new PutRecordCommand({ + 
StreamName: 'test_stream', + Data: Buffer.from('test'), + PartitionKey: 'partition_key', + }) + ); + res.statusCode = 200; + } catch (err) { + console.log('Error putting record', err); + res.statusCode = 500; + } + res.end(); + } else { + res.statusCode = 404; + res.end(); + } +} + +prepareAwsServer().then(() => { + server.listen(_PORT, '0.0.0.0', () => { + console.log('Server is listening on port', _PORT); + console.log('Ready'); + }); +}); diff --git a/contract-tests/images/applications/http/Dockerfile b/contract-tests/images/applications/http/Dockerfile new file mode 100644 index 00000000..70a52dea --- /dev/null +++ b/contract-tests/images/applications/http/Dockerfile @@ -0,0 +1,22 @@ +# Use an official Node.js runtime as the base image +FROM node:20-alpine +#FROM node:20 + +# Set the working directory inside the container +WORKDIR /http + +# Copy the relevant files +COPY ./dist/$DISTRO /http +COPY ./contract-tests/images/applications/http /http + + +ARG DISTRO +# Install dependencies +RUN npm install +RUN npm install ./${DISTRO} + +# Expose the port the app runs on +EXPOSE 8080 + +# Run the app with nodejs auto instrumentation +CMD ["node", "--require", "@aws/aws-distro-opentelemetry-node-autoinstrumentation/register", "server.js"] diff --git a/contract-tests/images/applications/http/package.json b/contract-tests/images/applications/http/package.json new file mode 100644 index 00000000..231edf01 --- /dev/null +++ b/contract-tests/images/applications/http/package.json @@ -0,0 +1,14 @@ +{ + "name": "http-server", + "version": "1.0.0", + "description": "A simple Node.js http server", + "main": "server.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1", + "start": "node server.js" + }, + "author": "", + "license": "ISC", + "dependencies": { + } +} diff --git a/contract-tests/images/applications/http/server.js b/contract-tests/images/applications/http/server.js new file mode 100644 index 00000000..b330254f --- /dev/null +++ 
b/contract-tests/images/applications/http/server.js @@ -0,0 +1,59 @@ +// server.js + +const http = require('http'); + +const PORT = 8080; +const NETWORK_ALIAS = 'backend'; +const SUCCESS = 'success'; +const ERROR = 'error'; +const FAULT = 'fault'; + +const server = http.createServer((req, res) => { + const method = req.method; + const url = req.url; + + // Helper function to check if a substring is in the path + const inPath = (subPath) => url.includes(subPath); + + if (inPath(`/${NETWORK_ALIAS}`)) { + let statusCode; + if (inPath(`/${SUCCESS}`)) { + statusCode = 200; + } else if (inPath(`/${ERROR}`)) { + statusCode = 400; + } else if (inPath(`/${FAULT}`)) { + statusCode = 500; + } else { + statusCode = 404; + } + res.writeHead(statusCode); + res.end(); + } else { + // Forward the request to http://backend:8080/backend{original_path} + const options = { + hostname: NETWORK_ALIAS, + port: PORT, + // port: 9090, + path: `/${NETWORK_ALIAS}${url}`, + method: method, + headers: req.headers, + timeout: 20000, // 20 seconds timeout + }; + + const proxyReq = http.request(options, (proxyRes) => { + res.writeHead(proxyRes.statusCode, proxyRes.headers); + proxyRes.pipe(res, { end: true }); + }); + + proxyReq.on('error', (err) => { + res.writeHead(500); + res.end('Proxy error'); + }); + + req.pipe(proxyReq, { end: true }); + } +}); + +server.listen(PORT, '0.0.0.0', () => { + console.log('Ready'); +}); diff --git a/contract-tests/images/applications/mongodb/Dockerfile b/contract-tests/images/applications/mongodb/Dockerfile new file mode 100644 index 00000000..968c8257 --- /dev/null +++ b/contract-tests/images/applications/mongodb/Dockerfile @@ -0,0 +1,21 @@ +# Use an official Node.js runtime as the base image +FROM node:20-alpine + +# Set the working directory inside the container +WORKDIR /mongodb + +# Copy the relevant files +COPY ./dist/$DISTRO /mongodb +COPY ./contract-tests/images/applications/mongodb /mongodb + + +ARG DISTRO +# Install dependencies +RUN npm install +RUN 
npm install ./${DISTRO} + +# Expose the port the app runs on +EXPOSE 8080 + +# Run the app with nodejs auto instrumentation +CMD ["node", "--require", "@aws/aws-distro-opentelemetry-node-autoinstrumentation/register", "server.js"] diff --git a/contract-tests/images/applications/mongodb/package.json b/contract-tests/images/applications/mongodb/package.json new file mode 100644 index 00000000..d540b5e1 --- /dev/null +++ b/contract-tests/images/applications/mongodb/package.json @@ -0,0 +1,14 @@ +{ + "name": "mongodb-forwarder", + "version": "1.0.0", + "main": "server.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "author": "", + "license": "ISC", + "description": "", + "dependencies": { + "mongodb": "^6.9.0" + } +} diff --git a/contract-tests/images/applications/mongodb/server.js b/contract-tests/images/applications/mongodb/server.js new file mode 100644 index 00000000..d3d46a78 --- /dev/null +++ b/contract-tests/images/applications/mongodb/server.js @@ -0,0 +1,164 @@ +// Import necessary modules +const http = require('http'); +const { MongoClient } = require('mongodb'); +const process = require('process'); +const url = require('url'); // For parsing URL parameters + +// Constants +const PORT = 8080; +const FIND_DOCUMENT = 'find'; +const INSERT_DOCUMENT = 'insert_document'; +const DELETE_DOCUMENT = 'delete_document'; +const UPDATE_DOCUMENT = 'update_document'; +const FAULT = 'fault'; + +// Environment variables for database connection +const DB_HOST = process.env.DB_HOST; +const DB_USER = process.env.DB_USER; +const DB_PASS = process.env.DB_PASS; +const DB_NAME = process.env.DB_NAME; + +// MongoDB connection URI +const mongoURI = `mongodb://${DB_USER}:${DB_PASS}@${DB_HOST}:27017/${DB_NAME}?authSource=admin`; + +console.log("Connect to MongoDB using " + mongoURI); + +// Create a new MongoClient +const client = new MongoClient(mongoURI, { useUnifiedTopology: true }); + +// Function to prepare the database server +async function 
prepareDbServer() { + try { + // Connect to the MongoDB server + await client.connect(); + console.log('MongoDB connection established'); + + const db = client.db(DB_NAME); + const collection = db.collection('employees'); + + // Check if the collection exists + const collections = await db.listCollections({ name: 'employees' }).toArray(); + if (collections.length === 0) { + // Collection does not exist, create it and insert a document + await collection.insertOne({ id: 0, name: 'to-be-updated' }); + await collection.insertOne({ id: 1, name: 'A' }); + console.log('Employee collection created and document inserted'); + } else { + console.log('Employee collection already exists'); + } + // Start the server after preparing the database + startServer(); + } catch (err) { + console.error('Error preparing database server:', err); + } +} + +// Function to start the HTTP server +function startServer() { + const server = http.createServer((req, res) => { + // Handle the request + if (req.method === 'GET') { + (async () => { + try { + await handleGetRequest(req, res); + } catch (err) { + console.error('Error in request handler:', err); + res.statusCode = 500; + res.end(); + } + })(); + } else { + res.statusCode = 405; // Method Not Allowed + res.end(); + } + }); + + server.listen(PORT, '0.0.0.0', () => { + console.log(`Server is listening on port ${PORT}`); + console.log('Ready'); + }); +} + +// Function to handle GET requests +async function handleGetRequest(req, res) { + let statusCode = 200; + const parsedUrl = url.parse(req.url, true); // Parse URL and query parameters + const pathname = parsedUrl.pathname; + + try { + const db = client.db(DB_NAME); + const collection = db.collection('employees'); + + if (pathname.includes(FIND_DOCUMENT)) { + // Retrieve documents + const employees = await collection.find({}).toArray(); + statusCode = 200; + res.statusCode = statusCode; + res.setHeader('Content-Type', 'application/json'); + res.end(JSON.stringify(employees)); + } else if 
(pathname.includes(INSERT_DOCUMENT)) { + // Insert a new document into the employee collection + // Extract 'id' and 'name' from query parameters + const id = parseInt(parsedUrl.query.id) || 2; + const name = parsedUrl.query.name || 'B'; + + await collection.insertOne({ id: id, name: name }); + console.log('New employee inserted'); + statusCode = 200; + res.statusCode = statusCode; + res.end(); + } else if (pathname.includes(DELETE_DOCUMENT)) { + // Delete employee with id = 1 + await collection.deleteOne({ id: 1 }); + console.log('Employee with id=1 deleted'); + statusCode = 200; + res.statusCode = statusCode; + res.end(); + } else if (pathname.includes(UPDATE_DOCUMENT)) { + // Update an existing employee entry + const id = 0; + const name = 'updated_name'; + + const result = await collection.findOneAndUpdate( + { id: id }, // Find the employee by id + { $set: { name: name } }, // Update the name field + { returnOriginal: false, upsert: true } // Return the updated document, create it if it doesn't exist + ); + + if (result) { + console.log(`Employee with id=${id} updated to name=${name}`); + statusCode = 200; + res.setHeader('Content-Type', 'application/json'); + res.end(JSON.stringify(result.value)); // Return updated employee as response + } else { + console.log('Employee not found'); + statusCode = 404; + res.statusCode = statusCode; + res.end(); + } + } else if (pathname.includes(FAULT)) { + // Try to execute an invalid MongoDB command to trigger an error + try { + await db.command({ invalidCommand: 1 }); + statusCode = 200; + } catch (err) { + console.error('Expected Exception with Invalid Command occurred:', err); + statusCode = 500; + } + res.statusCode = statusCode; + res.end(); + } else { + statusCode = 404; + res.statusCode = statusCode; + res.end(); + } + } catch (err) { + console.error('Error handling request:', err); + statusCode = 500; + res.statusCode = statusCode; + res.end(); + } +} + +// Start the database preparation and server 
+prepareDbServer(); diff --git a/contract-tests/images/applications/mongoose/Dockerfile b/contract-tests/images/applications/mongoose/Dockerfile new file mode 100644 index 00000000..ee6a8a79 --- /dev/null +++ b/contract-tests/images/applications/mongoose/Dockerfile @@ -0,0 +1,22 @@ +# Use an official Node.js runtime as the base image +FROM node:20-alpine +#FROM node:20 + +# Set the working directory inside the container +WORKDIR /mongoose + +# Copy the relevant files +COPY ./dist/$DISTRO /mongoose +COPY ./contract-tests/images/applications/mongoose /mongoose + + +ARG DISTRO +# Install dependencies +RUN npm install +RUN npm install ./${DISTRO} + +# Expose the port the app runs on +EXPOSE 8080 + +# Run the app with nodejs auto instrumentation +CMD ["node", "--require", "@aws/aws-distro-opentelemetry-node-autoinstrumentation/register", "server.js"] diff --git a/contract-tests/images/applications/mongoose/package.json b/contract-tests/images/applications/mongoose/package.json new file mode 100644 index 00000000..566790f5 --- /dev/null +++ b/contract-tests/images/applications/mongoose/package.json @@ -0,0 +1,15 @@ +{ + "name": "mongoose-forwarder", + "version": "1.0.0", + "main": "server.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1", + "start": "node server.js" + }, + "author": "", + "license": "ISC", + "description": "", + "dependencies": { + "mongoose": "^8.7.0" + } +} diff --git a/contract-tests/images/applications/mongoose/server.js b/contract-tests/images/applications/mongoose/server.js new file mode 100644 index 00000000..c39dcf81 --- /dev/null +++ b/contract-tests/images/applications/mongoose/server.js @@ -0,0 +1,171 @@ +// Import necessary modules +const http = require('http'); +const mongoose = require('mongoose'); +const process = require('process'); +const url = require('url'); // For parsing URL parameters + +// Constants +const PORT = 8080; +const FIND_DOCUMENT = 'find'; +const INSERT_DOCUMENT = 'insert_document'; +const 
DELETE_DOCUMENT = 'delete_document'; +const UPDATE_DOCUMENT = 'update_document'; +const FAULT = 'fault'; + +// Environment variables for database connection +const DB_HOST = process.env.DB_HOST; +const DB_USER = process.env.DB_USER; +const DB_PASS = process.env.DB_PASS; +const DB_NAME = process.env.DB_NAME; + +// MongoDB connection URI +const mongoURI = `mongodb://${DB_USER}:${DB_PASS}@${DB_HOST}:27017/${DB_NAME}?authSource=admin`; + +console.log("Connect to mongodb using " + mongoURI); +// Connect to MongoDB +mongoose.connect(mongoURI) + .then(() => { + console.log('MongoDB connection established'); + // Prepare the database server + prepareDbServer(); + }) + .catch(err => { + console.error('MongoDB connection error:', err); + }); + +// Define the Employee schema and model +const employeeSchema = new mongoose.Schema({ + id: Number, + name: String +}); + +const Employee = mongoose.model('Employee', employeeSchema); + +// Function to prepare the database server +async function prepareDbServer() { + try { + const collinfo = await mongoose.connection.db.listCollections({ name: 'employees' }).next(); + if (!collinfo) { + // Collection does not exist, create it and insert a document + const employee = new Employee({ id: 1, name: 'A' }); + await employee.save(); + console.log('Employee collection created and document inserted'); + } else { + console.log('Employee collection already exists'); + } + // Start the server after preparing the database + startServer(); + } catch (err) { + console.error('Error preparing database server:', err); + } +} + +// Function to start the HTTP server +function startServer() { + const server = http.createServer((req, res) => { + // Handle the request + if (req.method === 'GET') { + (async () => { + try { + await handleGetRequest(req, res); + } catch (err) { + console.error('Error in request handler:', err); + res.statusCode = 500; + res.end(); + } + })(); + } else { + res.statusCode = 405; // Method Not Allowed + res.end(); + } + }); + + 
server.listen(PORT, '0.0.0.0', () => { + console.log(`Server is listening on port ${PORT}`); + console.log('Ready'); + }); +} + +// Function to handle GET requests +async function handleGetRequest(req, res) { + let statusCode = 200; + const parsedUrl = url.parse(req.url, true); // Parse URL and query parameters + const pathname = parsedUrl.pathname; + + try { + if (pathname.includes(FIND_DOCUMENT)) { + // Use find operation to retrieve documents + const employees = await Employee.find({}); + statusCode = 200; + res.statusCode = statusCode; + res.setHeader('Content-Type', 'application/json'); + res.end(JSON.stringify(employees)); + } else if (pathname.includes(INSERT_DOCUMENT)) { + // Insert a new document into the employee collection + // Extract 'id' and 'name' from query parameters + const id = parseInt(parsedUrl.query.id) || 2; + const name = parsedUrl.query.name || 'B'; + + const newEmployee = new Employee({ id: id, name: name }); + await newEmployee.save(); + console.log('New employee inserted'); + statusCode = 200; + res.statusCode = statusCode; + res.end(); + } else if (pathname.includes(DELETE_DOCUMENT)) { + // Delete employee with id = 1 + await Employee.deleteOne({ id: 1 }); + console.log('Employee with id=1 deleted'); + statusCode = 200; + res.statusCode = statusCode; + res.end(); + } else if (pathname.includes(UPDATE_DOCUMENT)) { + try { + let id = 1; + let name = "updated_name" + const updatedEmployee = await Employee.findOneAndUpdate( + { id: id }, // Find the employee by id + { $set: { name: name } }, + { new: true, upsert: true } + ); + + if (updatedEmployee) { + console.log(`Employee with id=${id} updated to name=${name}`); + statusCode = 200; + } else { + console.log('Employee not found'); + statusCode = 404; + } + } catch (err) { + console.error('Error updating employee:', err); + statusCode = 500; + } + + res.statusCode = statusCode; + res.end(); + + } else if (pathname.includes(FAULT)) { + // We don't test this fault in our contract test 
because the span attributes are + // very different from other operations. + // Try to execute an invalid MongoDB command to trigger an error + try { + await mongoose.connection.db.command({ invalidCommand: 1 }); + statusCode = 200; + } catch (err) { + console.error('Expected Exception with Invalid Command occurred:', err); + statusCode = 500; + } + res.statusCode = statusCode; + res.end(); + } else { + statusCode = 404; + res.statusCode = statusCode; + res.end(); + } + } catch (err) { + console.error('Error handling request:', err); + statusCode = 500; + res.statusCode = statusCode; + res.end(); + } +} diff --git a/contract-tests/images/applications/mysql2/Dockerfile b/contract-tests/images/applications/mysql2/Dockerfile new file mode 100644 index 00000000..2e6ff46d --- /dev/null +++ b/contract-tests/images/applications/mysql2/Dockerfile @@ -0,0 +1,21 @@ +# Use an official Node.js runtime as the base image +FROM node:20-alpine + +# Set the working directory inside the container +WORKDIR /mysql2 + +# Copy the relevant files +COPY ./dist/$DISTRO /mysql2 +COPY ./contract-tests/images/applications/mysql2 /mysql2 + + +ARG DISTRO +# Install dependencies +RUN npm install +RUN npm install ./${DISTRO} + +# Expose the port the app runs on +EXPOSE 8080 + +# Run the app with nodejs auto instrumentation +CMD ["node", "--require", "@aws/aws-distro-opentelemetry-node-autoinstrumentation/register", "server.js"] diff --git a/contract-tests/images/applications/mysql2/package.json b/contract-tests/images/applications/mysql2/package.json new file mode 100644 index 00000000..edf0fcda --- /dev/null +++ b/contract-tests/images/applications/mysql2/package.json @@ -0,0 +1,15 @@ +{ + "name": "mysql-forwarder", + "version": "1.0.0", + "main": "server.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1", + "start": "node server.js" + }, + "author": "", + "license": "ISC", + "description": "", + "dependencies": { + "mysql2": "^3.11.3" + } +} diff --git 
a/contract-tests/images/applications/mysql2/server.js b/contract-tests/images/applications/mysql2/server.js new file mode 100644 index 00000000..cbd80201 --- /dev/null +++ b/contract-tests/images/applications/mysql2/server.js @@ -0,0 +1,101 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +const http = require('http'); +const mysql = require('mysql2/promise'); + +const SELECT = 'select'; +const CREATE_DATABASE = 'create_database'; +const DROP_TABLE = 'drop_table'; +const ERROR = 'error'; +const FAULT = 'fault'; +const PORT = 8080; + +const DB_HOST = process.env.DB_HOST; +const DB_USER = process.env.DB_USER; +const DB_PASS = process.env.DB_PASS; +const DB_NAME = process.env.DB_NAME; + +const pool = mysql.createPool({ + host: DB_HOST, + user: DB_USER, + password: DB_PASS, + database: DB_NAME, +}); + +async function prepareDbServer() { + try { + const connection = await pool.getConnection(); + + const [results] = await connection.execute("SHOW TABLES LIKE 'employee'"); + + if (results.length === 0) { + await connection.execute("CREATE TABLE employee (id int, name varchar(255))"); + await connection.execute("INSERT INTO employee (id, name) VALUES (1, 'A')"); + } + + connection.release(); + } catch (err) { + console.error('Error in prepareDbServer:', err); + throw err; + } +} + +async function main() { + try { + await prepareDbServer(); + + const server = http.createServer(async (req, res) => { + if (req.method === 'GET') { + let statusCode = 200; + const url = req.url; + let connection; + + try { + connection = await pool.getConnection(); + + if (url.includes(SELECT)) { + const [results] = await connection.execute("SELECT count(*) FROM employee"); + statusCode = results.length === 1 ? 
200 : 500; + } else if (url.includes(DROP_TABLE)) { + await connection.execute("DROP TABLE IF EXISTS test_table"); + statusCode = 200; + } else if (url.includes(CREATE_DATABASE)) { + await connection.execute("CREATE DATABASE test_database"); + statusCode = 200; + } else if (url.includes(FAULT)) { + try { + await connection.execute("SELECT DISTINCT id, name FROM invalid_table"); + statusCode = 200; + } catch (err) { + console.error("Expected Exception with Invalid SQL occurred:", err); + statusCode = 500; + } + } else { + statusCode = 404; + } + } catch (err) { + console.error('Error handling request:', err); + statusCode = 500; + } finally { + if (connection) connection.release(); + } + + res.writeHead(statusCode); + res.end(); + } else { + res.writeHead(405); // Method Not Allowed + res.end(); + } + }); + + server.listen(PORT, () => { + console.log('Ready'); + }); + + } catch (err) { + console.error('Error in main:', err); + } +} + +main(); diff --git a/contract-tests/images/mock-collector/Dockerfile b/contract-tests/images/mock-collector/Dockerfile new file mode 100644 index 00000000..bad4f42b --- /dev/null +++ b/contract-tests/images/mock-collector/Dockerfile @@ -0,0 +1,9 @@ +FROM python:3.10 +WORKDIR /mock-collector +COPY . /mock-collector + +ENV PIP_ROOT_USER_ACTION=ignore +RUN pip install --upgrade pip && pip install -r requirements.txt + +# Without `-u`, logs will be buffered and `wait_for_logs` will never return. +CMD ["python", "-u", "./mock_collector_server.py"] \ No newline at end of file diff --git a/contract-tests/images/mock-collector/README.md b/contract-tests/images/mock-collector/README.md new file mode 100644 index 00000000..1ed8572e --- /dev/null +++ b/contract-tests/images/mock-collector/README.md @@ -0,0 +1,11 @@ +### Overview + +MockCollector mimics the behaviour of the actual OTEL collector, but stores export requests to be retrieved by contract tests. + +### Protos + +To build protos: + +1. Run `pip install grpcio grpcio-tools` +2. 
Change directory to `aws-otel-js-instrumentation/contract-tests/images/mock-collector/` +3. Run: `python -m grpc_tools.protoc -I./protos --python_out=. --pyi_out=. --grpc_python_out=. ./protos/mock_collector_service.proto` diff --git a/contract-tests/images/mock-collector/mock_collector_client.py b/contract-tests/images/mock-collector/mock_collector_client.py new file mode 100644 index 00000000..21389cea --- /dev/null +++ b/contract-tests/images/mock-collector/mock_collector_client.py @@ -0,0 +1,143 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from datetime import datetime, timedelta +from logging import Logger, getLogger +from time import sleep +from typing import Callable, List, Set, TypeVar + +from google.protobuf.internal.containers import RepeatedScalarFieldContainer +from grpc import Channel, insecure_channel +from mock_collector_service_pb2 import ( + ClearRequest, + GetMetricsRequest, + GetMetricsResponse, + GetTracesRequest, + GetTracesResponse, +) +from mock_collector_service_pb2_grpc import MockCollectorServiceStub + +from opentelemetry.proto.collector.metrics.v1.metrics_service_pb2 import ExportMetricsServiceRequest +from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest +from opentelemetry.proto.metrics.v1.metrics_pb2 import Metric, ResourceMetrics, ScopeMetrics +from opentelemetry.proto.trace.v1.trace_pb2 import ResourceSpans, ScopeSpans, Span + +_logger: Logger = getLogger(__name__) +_TIMEOUT_DELAY: timedelta = timedelta(seconds=20) +_WAIT_INTERVAL_SEC: float = 0.1 +T: TypeVar = TypeVar("T") + + +class ResourceScopeSpan: + """Data class used to correlate resources, scope and telemetry signals. 
+ + Correlate resource, scope and span + """ + + def __init__(self, resource_spans: ResourceSpans, scope_spans: ScopeSpans, span: Span): + self.resource_spans: ResourceSpans = resource_spans + self.scope_spans: ScopeSpans = scope_spans + self.span: Span = span + + +class ResourceScopeMetric: + """Data class used to correlate resources, scope and telemetry signals. + + Correlate resource, scope and metric + """ + + def __init__(self, resource_metrics: ResourceMetrics, scope_metrics: ScopeMetrics, metric: Metric): + self.resource_metrics: ResourceMetrics = resource_metrics + self.scope_metrics: ScopeMetrics = scope_metrics + self.metric: Metric = metric + + +class MockCollectorClient: + """The mock collector client is used to interact with the Mock collector image, used in the tests.""" + + def __init__(self, mock_collector_address: str, mock_collector_port: str): + channel: Channel = insecure_channel(f"{mock_collector_address}:{mock_collector_port}") + self.client: MockCollectorServiceStub = MockCollectorServiceStub(channel) + + def clear_signals(self) -> None: + """Clear all the signals in the backend collector""" + self.client.clear(ClearRequest()) + + def get_traces(self) -> List[ResourceScopeSpan]: + """Get all traces that are currently stored in the collector + + Returns: + List of `ResourceScopeSpan` which is essentially a flat list containing all the spans and their related + scope and resources. 
+ """ + + def get_export() -> List[ExportTraceServiceRequest]: + response: GetTracesResponse = self.client.get_traces(GetTracesRequest()) + serialized_traces: RepeatedScalarFieldContainer[bytes] = response.traces + return list(map(ExportTraceServiceRequest.FromString, serialized_traces)) + + def wait_condition(exported: List[ExportTraceServiceRequest], current: List[ExportTraceServiceRequest]) -> bool: + return 0 < len(exported) == len(current) + + exported_traces: List[ExportTraceServiceRequest] = _wait_for_content(get_export, wait_condition) + spans: List[ResourceScopeSpan] = [] + for exported_trace in exported_traces: + for resource_span in exported_trace.resource_spans: + for scope_span in resource_span.scope_spans: + for span in scope_span.spans: + spans.append(ResourceScopeSpan(resource_span, scope_span, span)) + return spans + + def get_metrics(self, present_metrics: Set[str]) -> List[ResourceScopeMetric]: + """Get all metrics that are currently stored in the mock collector. + + Returns: + List of `ResourceScopeMetric` which is a flat list containing all metrics and their related scope and + resources. 
+ """ + + present_metrics_lower: Set[str] = {s.lower() for s in present_metrics} + + def get_export() -> List[ExportMetricsServiceRequest]: + response: GetMetricsResponse = self.client.get_metrics(GetMetricsRequest()) + serialized_metrics: RepeatedScalarFieldContainer[bytes] = response.metrics + return list(map(ExportMetricsServiceRequest.FromString, serialized_metrics)) + + def wait_condition( + exported: List[ExportMetricsServiceRequest], current: List[ExportMetricsServiceRequest] + ) -> bool: + received_metrics: Set[str] = set() + for exported_metric in current: + for resource_metric in exported_metric.resource_metrics: + for scope_metric in resource_metric.scope_metrics: + for metric in scope_metric.metrics: + received_metrics.add(metric.name.lower()) + return 0 < len(exported) == len(current) and present_metrics_lower.issubset(received_metrics) + + exported_metrics: List[ExportMetricsServiceRequest] = _wait_for_content(get_export, wait_condition) + metrics: List[ResourceScopeMetric] = [] + for exported_metric in exported_metrics: + for resource_metric in exported_metric.resource_metrics: + for scope_metric in resource_metric.scope_metrics: + for metric in scope_metric.metrics: + metrics.append(ResourceScopeMetric(resource_metric, scope_metric, metric)) + return metrics + + +def _wait_for_content(get_export: Callable[[], List[T]], wait_condition: Callable[[List[T], List[T]], bool]) -> List[T]: + # Verify that there is no more data to be received + deadline: datetime = datetime.now() + _TIMEOUT_DELAY + exported: List[T] = [] + + while deadline > datetime.now(): + try: + current_exported: List[T] = get_export() + if wait_condition(exported, current_exported): + return current_exported + exported = current_exported + + sleep(_WAIT_INTERVAL_SEC) + # pylint: disable=broad-exception-caught + except Exception: + _logger.exception("Error while reading content") + + raise RuntimeError("Timeout waiting for content") diff --git 
a/contract-tests/images/mock-collector/mock_collector_metrics_service.py b/contract-tests/images/mock-collector/mock_collector_metrics_service.py new file mode 100644 index 00000000..b20f6811 --- /dev/null +++ b/contract-tests/images/mock-collector/mock_collector_metrics_service.py @@ -0,0 +1,31 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from queue import Queue +from typing import List + +from grpc import ServicerContext +from typing_extensions import override + +from opentelemetry.proto.collector.metrics.v1.metrics_service_pb2 import ( + ExportMetricsServiceRequest, + ExportMetricsServiceResponse, +) +from opentelemetry.proto.collector.metrics.v1.metrics_service_pb2_grpc import MetricsServiceServicer + + +class MockCollectorMetricsService(MetricsServiceServicer): + _export_requests: Queue = Queue(maxsize=-1) + + def get_requests(self) -> List[ExportMetricsServiceRequest]: + with self._export_requests.mutex: + return list(self._export_requests.queue) + + def clear_requests(self) -> None: + with self._export_requests.mutex: + self._export_requests.queue.clear() + + @override + # pylint: disable=invalid-name + def Export(self, request: ExportMetricsServiceRequest, context: ServicerContext) -> ExportMetricsServiceResponse: + self._export_requests.put(request) + return ExportMetricsServiceResponse() diff --git a/contract-tests/images/mock-collector/mock_collector_server.py b/contract-tests/images/mock-collector/mock_collector_server.py new file mode 100644 index 00000000..21d60a1e --- /dev/null +++ b/contract-tests/images/mock-collector/mock_collector_server.py @@ -0,0 +1,35 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +import atexit +from concurrent.futures import ThreadPoolExecutor + +from grpc import server +from mock_collector_metrics_service import MockCollectorMetricsService +from mock_collector_service import MockCollectorService +from mock_collector_service_pb2_grpc import add_MockCollectorServiceServicer_to_server +from mock_collector_trace_service import MockCollectorTraceService + +from opentelemetry.proto.collector.metrics.v1.metrics_service_pb2_grpc import add_MetricsServiceServicer_to_server +from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import add_TraceServiceServicer_to_server + + +def main() -> None: + mock_collector_server: server = server(thread_pool=ThreadPoolExecutor(max_workers=10)) + mock_collector_server.add_insecure_port("0.0.0.0:4315") + + trace_collector: MockCollectorTraceService = MockCollectorTraceService() + metrics_collector: MockCollectorMetricsService = MockCollectorMetricsService() + mock_collector: MockCollectorService = MockCollectorService(trace_collector, metrics_collector) + + add_TraceServiceServicer_to_server(trace_collector, mock_collector_server) + add_MetricsServiceServicer_to_server(metrics_collector, mock_collector_server) + add_MockCollectorServiceServicer_to_server(mock_collector, mock_collector_server) + + mock_collector_server.start() + atexit.register(mock_collector_server.stop, None) + print("Ready") + mock_collector_server.wait_for_termination(None) + + +if __name__ == "__main__": + main() diff --git a/contract-tests/images/mock-collector/mock_collector_service.py b/contract-tests/images/mock-collector/mock_collector_service.py new file mode 100644 index 00000000..bc5fe0e9 --- /dev/null +++ b/contract-tests/images/mock-collector/mock_collector_service.py @@ -0,0 +1,52 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from typing import List + +from grpc import ServicerContext +from mock_collector_metrics_service import MockCollectorMetricsService +from mock_collector_service_pb2 import ( + ClearRequest, + ClearResponse, + GetMetricsRequest, + GetMetricsResponse, + GetTracesRequest, + GetTracesResponse, +) +from mock_collector_service_pb2_grpc import MockCollectorServiceServicer +from mock_collector_trace_service import MockCollectorTraceService +from typing_extensions import override + +from opentelemetry.proto.collector.metrics.v1.metrics_service_pb2 import ExportMetricsServiceRequest +from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest + + +class MockCollectorService(MockCollectorServiceServicer): + """Implements clear, get_traces, and get_metrics for the mock collector. + + Relies on metrics and trace collector services to collect the telemetry. + """ + + def __init__(self, trace_collector: MockCollectorTraceService, metrics_collector: MockCollectorMetricsService): + super().__init__() + self.trace_collector: MockCollectorTraceService = trace_collector + self.metrics_collector: MockCollectorMetricsService = metrics_collector + + @override + def clear(self, request: ClearRequest, context: ServicerContext) -> ClearResponse: + self.trace_collector.clear_requests() + self.metrics_collector.clear_requests() + return ClearResponse() + + @override + def get_traces(self, request: GetTracesRequest, context: ServicerContext) -> GetTracesResponse: + trace_requests: List[ExportTraceServiceRequest] = self.trace_collector.get_requests() + traces: List[bytes] = list(map(ExportTraceServiceRequest.SerializeToString, trace_requests)) + response: GetTracesResponse = GetTracesResponse(traces=traces) + return response + + @override + def get_metrics(self, request: GetMetricsRequest, context: ServicerContext) -> GetMetricsResponse: + metric_requests: List[ExportMetricsServiceRequest] = 
self.metrics_collector.get_requests() + metrics: List[bytes] = list(map(ExportMetricsServiceRequest.SerializeToString, metric_requests)) + response: GetMetricsResponse = GetMetricsResponse(metrics=metrics) + response diff --git a/contract-tests/images/mock-collector/mock_collector_service_pb2.py b/contract-tests/images/mock-collector/mock_collector_service_pb2.py new file mode 100644 index 00000000..2e519e57 --- /dev/null +++ b/contract-tests/images/mock-collector/mock_collector_service_pb2.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: mock_collector_service.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1cmock_collector_service.proto\"\x0e\n\x0c\x43learRequest\"\x0f\n\rClearResponse\"\x12\n\x10GetTracesRequest\"#\n\x11GetTracesResponse\x12\x0e\n\x06traces\x18\x01 \x03(\x0c\"\x13\n\x11GetMetricsRequest\"%\n\x12GetMetricsResponse\x12\x0f\n\x07metrics\x18\x01 \x03(\x0c\x32\xb1\x01\n\x14MockCollectorService\x12(\n\x05\x63lear\x12\r.ClearRequest\x1a\x0e.ClearResponse\"\x00\x12\x35\n\nget_traces\x12\x11.GetTracesRequest\x1a\x12.GetTracesResponse\"\x00\x12\x38\n\x0bget_metrics\x12\x12.GetMetricsRequest\x1a\x13.GetMetricsResponse\"\x00\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'mock_collector_service_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_CLEARREQUEST']._serialized_start=32 + 
_globals['_CLEARREQUEST']._serialized_end=46 + _globals['_CLEARRESPONSE']._serialized_start=48 + _globals['_CLEARRESPONSE']._serialized_end=63 + _globals['_GETTRACESREQUEST']._serialized_start=65 + _globals['_GETTRACESREQUEST']._serialized_end=83 + _globals['_GETTRACESRESPONSE']._serialized_start=85 + _globals['_GETTRACESRESPONSE']._serialized_end=120 + _globals['_GETMETRICSREQUEST']._serialized_start=122 + _globals['_GETMETRICSREQUEST']._serialized_end=141 + _globals['_GETMETRICSRESPONSE']._serialized_start=143 + _globals['_GETMETRICSRESPONSE']._serialized_end=180 + _globals['_MOCKCOLLECTORSERVICE']._serialized_start=183 + _globals['_MOCKCOLLECTORSERVICE']._serialized_end=360 +# @@protoc_insertion_point(module_scope) diff --git a/contract-tests/images/mock-collector/mock_collector_service_pb2.pyi b/contract-tests/images/mock-collector/mock_collector_service_pb2.pyi new file mode 100644 index 00000000..76bd1d24 --- /dev/null +++ b/contract-tests/images/mock-collector/mock_collector_service_pb2.pyi @@ -0,0 +1,34 @@ +from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Optional as _Optional + +DESCRIPTOR: _descriptor.FileDescriptor + +class ClearRequest(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class ClearResponse(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class GetTracesRequest(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class GetTracesResponse(_message.Message): + __slots__ = ("traces",) + TRACES_FIELD_NUMBER: _ClassVar[int] + traces: _containers.RepeatedScalarFieldContainer[bytes] + def __init__(self, traces: _Optional[_Iterable[bytes]] = ...) -> None: ... + +class GetMetricsRequest(_message.Message): + __slots__ = () + def __init__(self) -> None: ... 
+ +class GetMetricsResponse(_message.Message): + __slots__ = ("metrics",) + METRICS_FIELD_NUMBER: _ClassVar[int] + metrics: _containers.RepeatedScalarFieldContainer[bytes] + def __init__(self, metrics: _Optional[_Iterable[bytes]] = ...) -> None: ... diff --git a/contract-tests/images/mock-collector/mock_collector_service_pb2_grpc.py b/contract-tests/images/mock-collector/mock_collector_service_pb2_grpc.py new file mode 100644 index 00000000..d25f8fa6 --- /dev/null +++ b/contract-tests/images/mock-collector/mock_collector_service_pb2_grpc.py @@ -0,0 +1,138 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import mock_collector_service_pb2 as mock__collector__service__pb2 + + +class MockCollectorServiceStub(object): + """Service definition for mock collector + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.clear = channel.unary_unary( + '/MockCollectorService/clear', + request_serializer=mock__collector__service__pb2.ClearRequest.SerializeToString, + response_deserializer=mock__collector__service__pb2.ClearResponse.FromString, + ) + self.get_traces = channel.unary_unary( + '/MockCollectorService/get_traces', + request_serializer=mock__collector__service__pb2.GetTracesRequest.SerializeToString, + response_deserializer=mock__collector__service__pb2.GetTracesResponse.FromString, + ) + self.get_metrics = channel.unary_unary( + '/MockCollectorService/get_metrics', + request_serializer=mock__collector__service__pb2.GetMetricsRequest.SerializeToString, + response_deserializer=mock__collector__service__pb2.GetMetricsResponse.FromString, + ) + + +class MockCollectorServiceServicer(object): + """Service definition for mock collector + """ + + def clear(self, request, context): + """Clears all traces and metrics captured by mock collector, so it can be used for multiple tests. 
+ """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def get_traces(self, request, context): + """Returns traces exported to mock collector + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def get_metrics(self, request, context): + """Returns metrics exported to mock collector + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_MockCollectorServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'clear': grpc.unary_unary_rpc_method_handler( + servicer.clear, + request_deserializer=mock__collector__service__pb2.ClearRequest.FromString, + response_serializer=mock__collector__service__pb2.ClearResponse.SerializeToString, + ), + 'get_traces': grpc.unary_unary_rpc_method_handler( + servicer.get_traces, + request_deserializer=mock__collector__service__pb2.GetTracesRequest.FromString, + response_serializer=mock__collector__service__pb2.GetTracesResponse.SerializeToString, + ), + 'get_metrics': grpc.unary_unary_rpc_method_handler( + servicer.get_metrics, + request_deserializer=mock__collector__service__pb2.GetMetricsRequest.FromString, + response_serializer=mock__collector__service__pb2.GetMetricsResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'MockCollectorService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. 
+class MockCollectorService(object): + """Service definition for mock collector + """ + + @staticmethod + def clear(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/MockCollectorService/clear', + mock__collector__service__pb2.ClearRequest.SerializeToString, + mock__collector__service__pb2.ClearResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def get_traces(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/MockCollectorService/get_traces', + mock__collector__service__pb2.GetTracesRequest.SerializeToString, + mock__collector__service__pb2.GetTracesResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def get_metrics(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/MockCollectorService/get_metrics', + mock__collector__service__pb2.GetMetricsRequest.SerializeToString, + mock__collector__service__pb2.GetMetricsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/contract-tests/images/mock-collector/mock_collector_trace_service.py b/contract-tests/images/mock-collector/mock_collector_trace_service.py new file mode 100644 index 00000000..68ed9a03 --- /dev/null +++ 
b/contract-tests/images/mock-collector/mock_collector_trace_service.py @@ -0,0 +1,31 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from queue import Queue +from typing import List + +from grpc import ServicerContext +from typing_extensions import override + +from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( + ExportTraceServiceRequest, + ExportTraceServiceResponse, +) +from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import TraceServiceServicer + + +class MockCollectorTraceService(TraceServiceServicer): + _export_requests: Queue = Queue(maxsize=-1) + + def get_requests(self) -> List[ExportTraceServiceRequest]: + with self._export_requests.mutex: + return list(self._export_requests.queue) + + def clear_requests(self) -> None: + with self._export_requests.mutex: + self._export_requests.queue.clear() + + @override + # pylint: disable=invalid-name + def Export(self, request: ExportTraceServiceRequest, context: ServicerContext) -> ExportTraceServiceResponse: + self._export_requests.put(request) + return ExportTraceServiceResponse() diff --git a/contract-tests/images/mock-collector/protos/mock_collector_service.proto b/contract-tests/images/mock-collector/protos/mock_collector_service.proto new file mode 100644 index 00000000..d7187c44 --- /dev/null +++ b/contract-tests/images/mock-collector/protos/mock_collector_service.proto @@ -0,0 +1,35 @@ +syntax = "proto3"; + +// Service definition for mock collector +service MockCollectorService { + // Clears all traces and metrics captured by mock collector, so it can be used for multiple tests. + rpc clear (ClearRequest) returns (ClearResponse) {} + + // Returns traces exported to mock collector + rpc get_traces (GetTracesRequest) returns (GetTracesResponse) {} + + // Returns metrics exported to mock collector + rpc get_metrics (GetMetricsRequest) returns (GetMetricsResponse) {} +} + +// Empty request for clear rpc. 
+message ClearRequest {} + +// Empty response for clear rpc. +message ClearResponse {} + +// Empty request for get traces rpc. +message GetTracesRequest {} + +// Response for get traces rpc - all traces in byte form. +message GetTracesResponse{ + repeated bytes traces = 1; +} + +// Empty request for get metrics rpc. +message GetMetricsRequest {} + +// Response for get metrics rpc - all metrics in byte form. +message GetMetricsResponse { + repeated bytes metrics = 1; +} \ No newline at end of file diff --git a/contract-tests/images/mock-collector/pyproject.toml b/contract-tests/images/mock-collector/pyproject.toml new file mode 100644 index 00000000..76227044 --- /dev/null +++ b/contract-tests/images/mock-collector/pyproject.toml @@ -0,0 +1,24 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "mock-collector" +description = "Mock Collector used by contract tests for AWS OTEL NodeJS Instrumentation" +version = "1.0.0" +license = "Apache-2.0" +requires-python = ">=3.8" + +dependencies = [ + "grpcio ~= 1.60.0", + "opentelemetry-proto==1.25.0", + "opentelemetry-sdk==1.25.0", + "protobuf==4.25.2", + "typing-extensions==4.9.0" +] + +[tool.hatch.build.targets.sdist] +include = ["*.py"] + +[tool.hatch.build.targets.wheel] +include = ["*.py"] diff --git a/contract-tests/images/mock-collector/requirements.txt b/contract-tests/images/mock-collector/requirements.txt new file mode 100644 index 00000000..e536e81f --- /dev/null +++ b/contract-tests/images/mock-collector/requirements.txt @@ -0,0 +1,5 @@ +grpcio==1.60.1 +opentelemetry-proto==1.25.0 +opentelemetry-sdk==1.25.0 +protobuf==4.25.2 +typing-extensions==4.9.0 \ No newline at end of file diff --git a/contract-tests/tests/pyproject.toml b/contract-tests/tests/pyproject.toml new file mode 100644 index 00000000..f648ab2a --- /dev/null +++ b/contract-tests/tests/pyproject.toml @@ -0,0 +1,35 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + 
+[project] +name = "contract-tests" +description = "Contract tests for AWS OTEL NodeJS Instrumentation" +version = "1.0.0" +license = "Apache-2.0" +requires-python = ">=3.8" + +dependencies = [ + "opentelemetry-proto==1.25.0", + "opentelemetry-sdk==1.25.0", + "testcontainers==3.7.1", + "grpcio==1.60.0", + "docker==7.1.0", + "mock-collector==1.0.0", + "requests==2.32.2" +] + +[project.optional-dependencies] +test = [] + +[tool.hatch.build.targets.sdist] +include = ["/test"] + +[tool.hatch.build.targets.wheel] +packages = ["test/amazon"] + +[tool.pytest.ini_options] +log_cli = true +log_cli_level = "INFO" +log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" +log_cli_date_format = "%Y-%m-%d %H:%M:%S" \ No newline at end of file diff --git a/contract-tests/tests/test/amazon/aws-sdk/aws_sdk_test.py b/contract-tests/tests/test/amazon/aws-sdk/aws_sdk_test.py new file mode 100644 index 00000000..8773bb7c --- /dev/null +++ b/contract-tests/tests/test/amazon/aws-sdk/aws_sdk_test.py @@ -0,0 +1,585 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from logging import INFO, Logger, getLogger +from typing import Dict, List + +from docker.types import EndpointConfig +from mock_collector_client import ResourceScopeMetric, ResourceScopeSpan +from testcontainers.localstack import LocalStackContainer +from typing_extensions import override + +from amazon.base.contract_test_base import NETWORK_NAME, ContractTestBase +from amazon.utils.application_signals_constants import ( + AWS_LOCAL_OPERATION, + AWS_LOCAL_SERVICE, + AWS_REMOTE_OPERATION, + AWS_REMOTE_RESOURCE_IDENTIFIER, + AWS_REMOTE_RESOURCE_TYPE, + AWS_REMOTE_SERVICE, + AWS_SPAN_KIND, +) +from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue +from opentelemetry.proto.metrics.v1.metrics_pb2 import ExponentialHistogramDataPoint, Metric +from opentelemetry.proto.trace.v1.trace_pb2 import Span +from opentelemetry.semconv.trace import SpanAttributes + +_logger: Logger = getLogger(__name__) +_logger.setLevel(INFO) + +_AWS_SQS_QUEUE_URL: str = "aws.sqs.queue.url" +_AWS_SQS_QUEUE_NAME: str = "aws.sqs.queue.name" +_AWS_KINESIS_STREAM_NAME: str = "aws.kinesis.stream.name" + +# pylint: disable=too-many-public-methods +class AWSSDKTest(ContractTestBase): + _local_stack: LocalStackContainer + + def get_application_extra_environment_variables(self) -> Dict[str, str]: + return { + "AWS_SDK_S3_ENDPOINT": "http://s3.localstack:4566", + "AWS_SDK_ENDPOINT": "http://localstack:4566", + "AWS_REGION": "us-west-2", + } + + @override + def get_application_network_aliases(self) -> List[str]: + return ["error.test", "fault.test"] + + @override + @staticmethod + def get_application_image_name() -> str: + return "aws-application-signals-tests-aws-sdk-app" + + @classmethod + @override + def set_up_dependency_container(cls): + local_stack_networking_config: Dict[str, EndpointConfig] = { + NETWORK_NAME: EndpointConfig( + version="1.22", + aliases=[ + "localstack", + "s3.localstack", + ], + ) + } + cls._local_stack: LocalStackContainer 
= ( + LocalStackContainer(image="localstack/localstack:3.5.0") + .with_name("localstack") + .with_services("s3", "sqs", "dynamodb", "kinesis") + .with_env("DEFAULT_REGION", "us-west-2") + .with_kwargs(network=NETWORK_NAME, networking_config=local_stack_networking_config) + ) + cls._local_stack.start() + + @classmethod + @override + def tear_down_dependency_container(cls): + _logger.info("LocalStack stdout") + _logger.info(cls._local_stack.get_logs()[0].decode()) + _logger.info("LocalStack stderr") + _logger.info(cls._local_stack.get_logs()[1].decode()) + cls._local_stack.stop() + + def test_s3_create_bucket(self): + self.do_test_requests( + "s3/createbucket/create-bucket", + "GET", + 200, + 0, + 0, + local_operation="GET /s3", + remote_service="AWS::S3", + remote_operation="CreateBucket", + remote_resource_type="AWS::S3::Bucket", + remote_resource_identifier="test-bucket-name", + request_specific_attributes={ + SpanAttributes.AWS_S3_BUCKET: "test-bucket-name", + }, + span_name="S3.CreateBucket", + ) + + def test_s3_create_object(self): + self.do_test_requests( + "s3/createobject/put-object/some-object", + "GET", + 200, + 0, + 0, + local_operation="GET /s3", + remote_service="AWS::S3", + remote_operation="PutObject", + remote_resource_type="AWS::S3::Bucket", + remote_resource_identifier="test-put-object-bucket-name", + request_specific_attributes={ + SpanAttributes.AWS_S3_BUCKET: "test-put-object-bucket-name", + }, + span_name="S3.PutObject", + ) + + def test_s3_get_object(self): + self.do_test_requests( + "s3/getobject/get-object/some-object", + "GET", + 200, + 0, + 0, + local_operation="GET /s3", + remote_service="AWS::S3", + remote_operation="GetObject", + remote_resource_type="AWS::S3::Bucket", + remote_resource_identifier="test-get-object-bucket-name", + request_specific_attributes={ + SpanAttributes.AWS_S3_BUCKET: "test-get-object-bucket-name", + }, + span_name="S3.GetObject", + ) + + def test_s3_error(self): + self.do_test_requests( + "s3/error", + "GET", + 
400, + 1, + 0, + local_operation="GET /s3", + remote_service="AWS::S3", + remote_operation="CreateBucket", + remote_resource_type="AWS::S3::Bucket", + remote_resource_identifier="-", + request_specific_attributes={ + SpanAttributes.AWS_S3_BUCKET: "-", + }, + span_name="S3.CreateBucket", + ) + + def test_s3_fault(self): + self.do_test_requests( + "s3/fault", + "GET", + 500, + 0, + 1, + dp_count=3, + local_operation="GET /s3", + local_operation_2="PUT /valid-bucket-name", + remote_service="AWS::S3", + remote_operation="CreateBucket", + remote_resource_type="AWS::S3::Bucket", + remote_resource_identifier="valid-bucket-name", + request_specific_attributes={ + SpanAttributes.AWS_S3_BUCKET: "valid-bucket-name", + }, + span_name="S3.CreateBucket", + ) + + def test_dynamodb_create_table(self): + self.do_test_requests( + "ddb/createtable/some-table", + "GET", + 200, + 0, + 0, + local_operation="GET /ddb", + remote_service="AWS::DynamoDB", + remote_operation="CreateTable", + remote_resource_type="AWS::DynamoDB::Table", + remote_resource_identifier="test_table", + request_specific_attributes={ + SpanAttributes.AWS_DYNAMODB_TABLE_NAMES: ["test_table"], + }, + span_name="DynamoDB.CreateTable", + ) + + def test_dynamodb_put_item(self): + self.do_test_requests( + "ddb/putitem/putitem-table/key", + "GET", + 200, + 0, + 0, + local_operation="GET /ddb", + remote_service="AWS::DynamoDB", + remote_operation="PutItem", + remote_resource_type="AWS::DynamoDB::Table", + remote_resource_identifier="put_test_table", + request_specific_attributes={ + SpanAttributes.AWS_DYNAMODB_TABLE_NAMES: ["put_test_table"], + }, + span_name="DynamoDB.PutItem", + ) + + def test_dynamodb_error(self): + self.do_test_requests( + "ddb/error", + "GET", + 400, + 1, + 0, + local_operation="GET /ddb", + remote_service="AWS::DynamoDB", + remote_operation="PutItem", + remote_resource_type="AWS::DynamoDB::Table", + remote_resource_identifier="invalid_table", + request_specific_attributes={ + 
SpanAttributes.AWS_DYNAMODB_TABLE_NAMES: ["invalid_table"], + }, + span_name="DynamoDB.PutItem", + ) + + def test_dynamodb_fault(self): + self.do_test_requests( + "ddb/fault", + "GET", + 500, + 0, + 1, + dp_count=3, + local_operation="GET /ddb", + local_operation_2="POST /", # for the fake ddb service + remote_service="AWS::DynamoDB", + remote_operation="PutItem", + remote_resource_type="AWS::DynamoDB::Table", + remote_resource_identifier="invalid_table", + request_specific_attributes={ + SpanAttributes.AWS_DYNAMODB_TABLE_NAMES: ["invalid_table"], + }, + span_name="DynamoDB.PutItem", + ) + + def test_sqs_create_queue(self): + self.do_test_requests( + "sqs/createqueue/some-queue", + "GET", + 200, + 0, + 0, + local_operation="GET /sqs", + remote_service="AWS::SQS", + remote_operation="CreateQueue", + remote_resource_type="AWS::SQS::Queue", + remote_resource_identifier="test_queue", + request_specific_attributes={ + _AWS_SQS_QUEUE_NAME: "test_queue", + }, + span_name="SQS.CreateQueue", + ) + + def test_sqs_send_message(self): + self.do_test_requests( + "sqs/publishqueue/some-queue", + "GET", + 200, + 0, + 0, + select_span_kind=Span.SPAN_KIND_PRODUCER, + local_operation="GET /sqs", + remote_service="AWS::SQS", + remote_operation="SendMessage", + remote_resource_type="AWS::SQS::Queue", + remote_resource_identifier="test_put_get_queue", + request_specific_attributes={ + _AWS_SQS_QUEUE_URL: "http://localstack:4566/000000000000/test_put_get_queue", + }, + span_name="test_put_get_queue send", # the span name is decided by upstream, but doesn't matter for app signals + dependency_metric_span_kind="PRODUCER", + ) + + def test_sqs_receive_message(self): + self.do_test_requests( + "sqs/consumequeue/some-queue", + "GET", + 200, + 0, + 0, + select_span_kind=Span.SPAN_KIND_CONSUMER, + local_operation="GET /sqs", + remote_service="AWS::SQS", + remote_operation="ReceiveMessage", + remote_resource_type="AWS::SQS::Queue", + remote_resource_identifier="test_put_get_queue", + 
request_specific_attributes={ + _AWS_SQS_QUEUE_URL: "http://localstack:4566/000000000000/test_put_get_queue", + }, + span_name="test_put_get_queue receive", # the span name is decided by upstream, but doesn't matter for app signals + dependency_metric_span_kind="CONSUMER", + ) + + def test_sqs_error(self): + self.do_test_requests( + "sqs/error", + "GET", + 400, + 1, + 0, + select_span_kind=Span.SPAN_KIND_PRODUCER, + local_operation="GET /sqs", + remote_service="AWS::SQS", + remote_operation="SendMessage", + remote_resource_type="AWS::SQS::Queue", + remote_resource_identifier="sqserror", + request_specific_attributes={ + _AWS_SQS_QUEUE_URL: "http://error.test:8080/000000000000/sqserror", + }, + span_name="sqserror send", # the span name is decided by upstream, but doesn't matter for app signals + dependency_metric_span_kind="PRODUCER", + ) + + def test_sqs_fault(self): + self.do_test_requests( + "sqs/fault", + "GET", + 500, + 0, + 1, + dp_count=3, + local_operation="GET /sqs", + local_operation_2="POST /", + remote_service="AWS::SQS", + remote_operation="CreateQueue", + remote_resource_type="AWS::SQS::Queue", + remote_resource_identifier="invalid_test", + request_specific_attributes={ + _AWS_SQS_QUEUE_NAME: "invalid_test", + }, + span_name="SQS.CreateQueue", + ) + + def test_kinesis_put_record(self): + self.do_test_requests( + "kinesis/putrecord/my-stream", + "GET", + 200, + 0, + 0, + local_operation="GET /kinesis", + remote_service="AWS::Kinesis", + remote_operation="PutRecord", + remote_resource_type="AWS::Kinesis::Stream", + remote_resource_identifier="test_stream", + request_specific_attributes={ + _AWS_KINESIS_STREAM_NAME: "test_stream", + }, + span_name="Kinesis.PutRecord", + ) + + def test_kinesis_error(self): + self.do_test_requests( + "kinesis/error", + "GET", + 400, + 1, + 0, + local_operation="GET /kinesis", + remote_service="AWS::Kinesis", + remote_operation="PutRecord", + remote_resource_type="AWS::Kinesis::Stream", + 
remote_resource_identifier="invalid_stream", + request_specific_attributes={ + _AWS_KINESIS_STREAM_NAME: "invalid_stream", + }, + span_name="Kinesis.PutRecord", + ) + + def test_kinesis_fault(self): + self.do_test_requests( + "kinesis/fault", + "GET", + 500, + 0, + 1, + local_operation="GET /kinesis", + local_operation_2="POST /", + dp_count=3, + remote_service="AWS::Kinesis", + remote_operation="PutRecord", + remote_resource_type="AWS::Kinesis::Stream", + remote_resource_identifier="test_stream", + request_specific_attributes={ + _AWS_KINESIS_STREAM_NAME: "test_stream", + }, + span_name="Kinesis.PutRecord", + ) + + @override + def _assert_aws_span_attributes(self, resource_scope_spans: List[ResourceScopeSpan], path: str, **kwargs) -> None: + target_spans: List[Span] = [] + for resource_scope_span in resource_scope_spans: + # pylint: disable=no-member + selected_span_kind = kwargs.get("select_span_kind") or Span.SPAN_KIND_CLIENT + if resource_scope_span.span.kind == selected_span_kind: + target_spans.append(resource_scope_span.span) + + if selected_span_kind == Span.SPAN_KIND_CLIENT: + span_kind = "CLIENT"; + elif selected_span_kind == Span.SPAN_KIND_PRODUCER: + span_kind = "PRODUCER"; + elif selected_span_kind == Span.SPAN_KIND_CONSUMER: + span_kind = "CONSUMER"; + + self.assertEqual(len(target_spans), 1) + self._assert_aws_attributes( + target_spans[0].attributes, + kwargs.get("local_operation"), + kwargs.get("remote_service"), + kwargs.get("remote_operation"), + span_kind, + kwargs.get("remote_resource_type", "None"), + kwargs.get("remote_resource_identifier", "None"), + ) + + def _assert_aws_attributes( + self, + attributes_list: List[KeyValue], + local_operation: str, + remote_service: str, + remote_operation: str, + span_kind: str, + remote_resource_type: str, + remote_resource_identifier: str, + ) -> None: + attributes_dict: Dict[str, AnyValue] = self._get_attributes_dict(attributes_list) + self._assert_str_attribute(attributes_dict, AWS_LOCAL_SERVICE, 
self.get_application_otel_service_name()) + self._assert_str_attribute(attributes_dict, AWS_LOCAL_OPERATION, local_operation) + self._assert_str_attribute(attributes_dict, AWS_REMOTE_SERVICE, remote_service) + self._assert_str_attribute(attributes_dict, AWS_REMOTE_OPERATION, remote_operation) + if remote_resource_type != "None": + self._assert_str_attribute(attributes_dict, AWS_REMOTE_RESOURCE_TYPE, remote_resource_type) + if remote_resource_identifier != "None": + self._assert_str_attribute(attributes_dict, AWS_REMOTE_RESOURCE_IDENTIFIER, remote_resource_identifier) + self._assert_str_attribute(attributes_dict, AWS_SPAN_KIND, span_kind) + + @override + def _assert_semantic_conventions_span_attributes( + self, resource_scope_spans: List[ResourceScopeSpan], method: str, path: str, status_code: int, **kwargs + ) -> None: + target_spans: List[Span] = [] + for resource_scope_span in resource_scope_spans: + # pylint: disable=no-member + selected_span_kind = kwargs.get("select_span_kind") or Span.SPAN_KIND_CLIENT + if resource_scope_span.span.kind == selected_span_kind: + target_spans.append(resource_scope_span.span) + + self.assertEqual(len(target_spans), 1) + self.assertEqual(target_spans[0].name, kwargs.get("span_name")) + self._assert_semantic_conventions_attributes( + target_spans[0].attributes, + kwargs.get("rpc_service") if "rpc_service" in kwargs else kwargs.get("remote_service").split("::")[-1], + kwargs.get("remote_operation"), + status_code, + kwargs.get("request_specific_attributes", {}), + ) + + # pylint: disable=unidiomatic-typecheck + def _assert_semantic_conventions_attributes( + self, + attributes_list: List[KeyValue], + service: str, + operation: str, + status_code: int, + request_specific_attributes: dict, + ) -> None: + attributes_dict: Dict[str, AnyValue] = self._get_attributes_dict(attributes_list) + self._assert_str_attribute(attributes_dict, SpanAttributes.RPC_METHOD, operation) + self._assert_str_attribute(attributes_dict, 
SpanAttributes.RPC_SYSTEM, "aws-api") + self._assert_str_attribute(attributes_dict, SpanAttributes.RPC_SERVICE, service.split("::")[-1]) + self._assert_int_attribute(attributes_dict, SpanAttributes.HTTP_STATUS_CODE, status_code) + # TODO: aws sdk instrumentation is not respecting PEER_SERVICE + # self._assert_str_attribute(attributes_dict, SpanAttributes.PEER_SERVICE, "backend:8080") + for key, value in request_specific_attributes.items(): + if isinstance(value, str): + self._assert_str_attribute(attributes_dict, key, value) + elif isinstance(value, int): + self._assert_int_attribute(attributes_dict, key, value) + else: + self._assert_array_value_ddb_table_name(attributes_dict, key, value) + + @override + def _assert_metric_attributes( + self, + resource_scope_metrics: List[ResourceScopeMetric], + metric_name: str, + expected_sum: int, + **kwargs, + ) -> None: + target_metrics: List[Metric] = [] + for resource_scope_metric in resource_scope_metrics: + if resource_scope_metric.metric.name.lower() == metric_name.lower(): + target_metrics.append(resource_scope_metric.metric) + + self.assertEqual(len(target_metrics), 1) + target_metric: Metric = target_metrics[0] + dp_list: List[ExponentialHistogramDataPoint] = target_metric.exponential_histogram.data_points + dp_list_count: int = kwargs.get("dp_count", 2) + self.assertEqual(len(dp_list), dp_list_count) + + if (len(dp_list) == 2): + dependency_dp: ExponentialHistogramDataPoint = dp_list[0] + service_dp: ExponentialHistogramDataPoint = dp_list[1] + if len(dp_list[1].attributes) > len(dp_list[0].attributes): + dependency_dp = dp_list[1] + service_dp = dp_list[0] + attribute_dict: Dict[str, AnyValue] = self._get_attributes_dict(dependency_dp.attributes) + self._assert_str_attribute(attribute_dict, AWS_LOCAL_SERVICE, self.get_application_otel_service_name()) + self._assert_str_attribute(attribute_dict, AWS_LOCAL_OPERATION, kwargs.get("local_operation")) + self._assert_str_attribute(attribute_dict, AWS_REMOTE_SERVICE, 
kwargs.get("remote_service")) + self._assert_str_attribute(attribute_dict, AWS_REMOTE_OPERATION, kwargs.get("remote_operation")) + self._assert_str_attribute(attribute_dict, AWS_SPAN_KIND, kwargs.get("dependency_metric_span_kind") or "CLIENT") + remote_resource_type = kwargs.get("remote_resource_type", "None") + remote_resource_identifier = kwargs.get("remote_resource_identifier", "None") + if remote_resource_type != "None": + self._assert_str_attribute(attribute_dict, AWS_REMOTE_RESOURCE_TYPE, remote_resource_type) + if remote_resource_identifier != "None": + self._assert_str_attribute(attribute_dict, AWS_REMOTE_RESOURCE_IDENTIFIER, remote_resource_identifier) + self.check_sum(metric_name, dependency_dp.sum, expected_sum) + + attribute_dict: Dict[str, AnyValue] = self._get_attributes_dict(service_dp.attributes) + self._assert_str_attribute(attribute_dict, AWS_LOCAL_SERVICE, self.get_application_otel_service_name()) + self._assert_str_attribute(attribute_dict, AWS_LOCAL_OPERATION, kwargs.get("local_operation")) + self._assert_str_attribute(attribute_dict, AWS_SPAN_KIND, "LOCAL_ROOT") + self.check_sum(metric_name, service_dp.sum, expected_sum) + else: + dependency_dp: ExponentialHistogramDataPoint = max(dp_list, key=lambda dp: len(dp.attributes)) + # Assign the remaining two elements to dependency_dp and other_dp + remaining_dps = [dp for dp in dp_list if dp != dependency_dp] + service_dp, other_dp = remaining_dps[0], remaining_dps[1] + + attribute_dict: Dict[str, AnyValue] = self._get_attributes_dict(dependency_dp.attributes) + self._assert_str_attribute(attribute_dict, AWS_LOCAL_SERVICE, self.get_application_otel_service_name()) + self._assert_str_attribute(attribute_dict, AWS_LOCAL_OPERATION, kwargs.get("local_operation")) + self._assert_str_attribute(attribute_dict, AWS_REMOTE_SERVICE, kwargs.get("remote_service")) + self._assert_str_attribute(attribute_dict, AWS_REMOTE_OPERATION, kwargs.get("remote_operation")) + self._assert_str_attribute(attribute_dict, 
AWS_SPAN_KIND, kwargs.get("dependency_metric_span_kind") or "CLIENT") + remote_resource_type = kwargs.get("remote_resource_type", "None") + remote_resource_identifier = kwargs.get("remote_resource_identifier", "None") + if remote_resource_type != "None": + self._assert_str_attribute(attribute_dict, AWS_REMOTE_RESOURCE_TYPE, remote_resource_type) + if remote_resource_identifier != "None": + self._assert_str_attribute(attribute_dict, AWS_REMOTE_RESOURCE_IDENTIFIER, remote_resource_identifier) + self.check_sum(metric_name, dependency_dp.sum, expected_sum) + + attribute_dict_service: Dict[str, AnyValue] = self._get_attributes_dict(service_dp.attributes) + attribute_dict_other: Dict[str, AnyValue] = self._get_attributes_dict(other_dp.attributes) + + # test AWS_LOCAL_OPERATION to be either kwargs.get("local_operation_2") or kwargs.get("local_operation") in service_dp and other_dp + if kwargs.get("local_operation") not in [attribute_dict_service.get(AWS_LOCAL_OPERATION)]: + self._assert_str_attribute(attribute_dict_service, AWS_LOCAL_OPERATION, kwargs.get("local_operation_2")) + self._assert_str_attribute(attribute_dict_other, AWS_LOCAL_OPERATION, kwargs.get("local_operation")) + else: + self._assert_str_attribute(attribute_dict_service, AWS_LOCAL_OPERATION, kwargs.get("local_operation")) + self._assert_str_attribute(attribute_dict_other, AWS_LOCAL_OPERATION, kwargs.get("local_operation_2")) + + self._assert_str_attribute(attribute_dict_service, AWS_LOCAL_SERVICE, self.get_application_otel_service_name()) + self._assert_str_attribute(attribute_dict_service, AWS_SPAN_KIND, "LOCAL_ROOT") + self.check_sum(metric_name, service_dp.sum, expected_sum) + + self._assert_str_attribute(attribute_dict_other, AWS_LOCAL_SERVICE, self.get_application_otel_service_name()) + self._assert_str_attribute(attribute_dict_other, AWS_SPAN_KIND, "LOCAL_ROOT") + self.check_sum(metric_name, other_dp.sum, expected_sum) + + # pylint: disable=consider-using-enumerate + def 
_assert_array_value_ddb_table_name(self, attributes_dict: Dict[str, AnyValue], key: str, expect_values: list): + self.assertIn(key, attributes_dict) + actual_values: [AnyValue] = attributes_dict[key].array_value + self.assertEqual(len(actual_values.values), len(expect_values)) + for index in range(len(actual_values.values)): + self.assertEqual(actual_values.values[index].string_value, expect_values[index]) diff --git a/contract-tests/tests/test/amazon/base/contract_test_base.py b/contract-tests/tests/test/amazon/base/contract_test_base.py new file mode 100644 index 00000000..af6557c7 --- /dev/null +++ b/contract-tests/tests/test/amazon/base/contract_test_base.py @@ -0,0 +1,282 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import time +from logging import INFO, Logger, getLogger +from typing import Dict, List +from unittest import TestCase + +from docker import DockerClient +from docker.models.networks import Network, NetworkCollection +from docker.types import EndpointConfig +from mock_collector_client import MockCollectorClient, ResourceScopeMetric, ResourceScopeSpan +from requests import Response, request +from testcontainers.core.container import DockerContainer +from testcontainers.core.waiting_utils import wait_for_logs +from typing_extensions import override + +from amazon.utils.application_signals_constants import ERROR_METRIC, FAULT_METRIC, LATENCY_METRIC +from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue + +NETWORK_NAME: str = "aws-application-signals-network" + +_logger: Logger = getLogger(__name__) +_logger.setLevel(INFO) +_MOCK_COLLECTOR_ALIAS: str = "collector" +_MOCK_COLLECTOR_NAME: str = "aws-application-signals-mock-collector-nodejs" +_MOCK_COLLECTOR_PORT: int = 4315 + +def any_value_to_string(any_value_instance): + field_name = any_value_instance.WhichOneof('value') + + if field_name == 'string_value': + # Already a string + return 
any_value_instance.string_value + + elif field_name == 'bool_value': + # Convert boolean to string + return str(any_value_instance.bool_value) + + elif field_name == 'int_value': + # Convert integer to string + return str(any_value_instance.int_value) + + elif field_name == 'double_value': + # Convert double to string + return str(any_value_instance.double_value) + + elif field_name == 'bytes_value': + # Attempt to decode bytes to string + try: + return any_value_instance.bytes_value.decode('utf-8') + except UnicodeDecodeError: + # Handle decoding error + return None + + elif field_name == 'array_value': + # Convert each element in the array to string + elements = [] + for item in any_value_instance.array_value.values: + item_str = any_value_to_string(item) + if item_str is not None: + elements.append(item_str) + else: + # Cannot convert an element; return None or handle accordingly + return None + return '[' + ', '.join(elements) + ']' + + elif field_name == 'kvlist_value': + # Convert each key-value pair to string + kv_pairs = [] + for kv in any_value_instance.kvlist_value.values: + key = kv.key + value_str = any_value_to_string(kv.value) + if value_str is not None: + kv_pairs.append(f'"{key}": {value_str}') + else: + # Cannot convert a value; return None or handle accordingly + return None + return '{' + ', '.join(kv_pairs) + '}' + + else: + # No field is set or unknown field; cannot convert + return None + +# pylint: disable=broad-exception-caught +class ContractTestBase(TestCase): + """Base class for implementing a contract test. + + This class will create all the boilerplate necessary to run a contract test. It will: 1.Create a mock collector + container that receives telemetry data of the application being tested. 2. Create an application container which + will be used to exercise the library under test. + + Several methods are provided that can be overridden to customize the test scenario. 
+ """ + + application: DockerContainer + mock_collector: DockerContainer + mock_collector_client: MockCollectorClient + network: Network + + @classmethod + @override + def setUpClass(cls) -> None: + cls.addClassCleanup(cls.class_tear_down) + cls.network = NetworkCollection(client=DockerClient()).create(NETWORK_NAME) + mock_collector_networking_config: Dict[str, EndpointConfig] = { + NETWORK_NAME: EndpointConfig(version="1.22", aliases=[_MOCK_COLLECTOR_ALIAS]) + } + cls.mock_collector: DockerContainer = ( + DockerContainer(_MOCK_COLLECTOR_NAME) + .with_exposed_ports(_MOCK_COLLECTOR_PORT) + .with_name(_MOCK_COLLECTOR_NAME) + .with_kwargs(network=NETWORK_NAME, networking_config=mock_collector_networking_config) + ) + cls.mock_collector.start() + wait_for_logs(cls.mock_collector, "Ready", timeout=20) + cls.set_up_dependency_container() + + @classmethod + def class_tear_down(cls) -> None: + try: + cls.tear_down_dependency_container() + except Exception: + _logger.exception("Failed to tear down dependency container") + + try: + _logger.info("MockCollector stdout") + _logger.info(cls.mock_collector.get_logs()[0].decode()) + _logger.info("MockCollector stderr") + _logger.info(cls.mock_collector.get_logs()[1].decode()) + cls.mock_collector.stop() + except Exception: + _logger.exception("Failed to tear down mock collector") + + cls.network.remove() + + @override + def setUp(self) -> None: + self.addCleanup(self.tear_down) + application_networking_config: Dict[str, EndpointConfig] = { + NETWORK_NAME: EndpointConfig(version="1.22", aliases=self.get_application_network_aliases()) + } + self.application: DockerContainer = ( + DockerContainer(self.get_application_image_name()) + .with_exposed_ports(self.get_application_port()) + .with_env("OTEL_METRIC_EXPORT_INTERVAL", "1000") + .with_env("OTEL_AWS_APPLICATION_SIGNALS_ENABLED", "true") + .with_env("OTEL_METRICS_EXPORTER", "none") + .with_env("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc") + .with_env("OTEL_BSP_SCHEDULE_DELAY", "1") + 
.with_env("OTEL_AWS_APPLICATION_SIGNALS_EXPORTER_ENDPOINT", f"http://collector:{_MOCK_COLLECTOR_PORT}") + .with_env("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", f"http://collector:{_MOCK_COLLECTOR_PORT}") + .with_env("OTEL_RESOURCE_ATTRIBUTES", self.get_application_otel_resource_attributes()) + .with_env("OTEL_TRACES_SAMPLER", "always_on") + .with_kwargs(network=NETWORK_NAME, networking_config=application_networking_config) + .with_name(self.get_application_image_name()) + ) + + extra_env: Dict[str, str] = self.get_application_extra_environment_variables() + for key in extra_env: + self.application.with_env(key, extra_env.get(key)) + self.application.start() + wait_for_logs(self.application, self.get_application_wait_pattern(), timeout=20) + self.mock_collector_client: MockCollectorClient = MockCollectorClient( + self.mock_collector.get_container_host_ip(), self.mock_collector.get_exposed_port(_MOCK_COLLECTOR_PORT) + ) + # Sleep for 3s to ensure any startup metrics have been exported + time.sleep(3) + # Clear all start up metrics, so tests are only testing telemetry generated by their invocations. 
+ self.mock_collector_client.clear_signals() + + def tear_down(self) -> None: + try: + _logger.info("Application stdout") + _logger.info(self.application.get_logs()[0].decode()) + _logger.info("Application stderr") + _logger.info(self.application.get_logs()[1].decode()) + self.application.stop() + except Exception: + _logger.exception("Failed to tear down application") + + self.mock_collector_client.clear_signals() + + def do_test_requests( + self, path: str, method: str, status_code: int, expected_error: int, expected_fault: int, **kwargs + ) -> None: + response: Response = self.send_request(method, path) + self.assertEqual(status_code, response.status_code) + + resource_scope_spans: List[ResourceScopeSpan] = self.mock_collector_client.get_traces() + self._assert_aws_span_attributes(resource_scope_spans, path, **kwargs) + self._assert_semantic_conventions_span_attributes(resource_scope_spans, method, path, status_code, **kwargs) + + metrics: List[ResourceScopeMetric] = self.mock_collector_client.get_metrics( + {LATENCY_METRIC, ERROR_METRIC, FAULT_METRIC} + ) + self._assert_metric_attributes(metrics, LATENCY_METRIC, 5000, **kwargs) + self._assert_metric_attributes(metrics, ERROR_METRIC, expected_error, **kwargs) + self._assert_metric_attributes(metrics, FAULT_METRIC, expected_fault, **kwargs) + + def send_request(self, method, path) -> Response: + address: str = self.application.get_container_host_ip() + port: str = self.application.get_exposed_port(self.get_application_port()) + url: str = f"http://{address}:{port}/{path}" + _logger.info("send request to url: " + url) + return request(method, url, timeout=20) + + def _get_attributes_dict(self, attributes_list: List[KeyValue]) -> Dict[str, AnyValue]: + # _logger.info("Get the attributes dictionary ==============") + attributes_dict: Dict[str, AnyValue] = {} + for attribute in attributes_list: + key: str = attribute.key + value: AnyValue = attribute.value + # _logger.info("key: " + key + " value: " + 
any_value_to_string(value)) + + if key in attributes_dict: + old_value: AnyValue = attributes_dict[key] + self.fail(f"Attribute {key} unexpectedly duplicated. Value 1: {old_value} Value 2: {value}") + attributes_dict[key] = value + return attributes_dict + + def _assert_str_attribute(self, attributes_dict: Dict[str, AnyValue], key: str, expected_value: str): + self.assertIn(key, attributes_dict) + actual_value: AnyValue = attributes_dict[key] + self.assertIsNotNone(actual_value) + self.assertEqual(expected_value, actual_value.string_value) + + def _assert_int_attribute(self, attributes_dict: Dict[str, AnyValue], key: str, expected_value: int) -> None: + self.assertIn(key, attributes_dict) + actual_value: AnyValue = attributes_dict[key] + self.assertIsNotNone(actual_value) + self.assertEqual(expected_value, actual_value.int_value) + + def check_sum(self, metric_name: str, actual_sum: float, expected_sum: float) -> None: + if metric_name is LATENCY_METRIC: + self.assertTrue(0 < actual_sum < expected_sum) + else: + self.assertEqual(actual_sum, expected_sum) + + # pylint: disable=no-self-use + # Methods that should be overridden in subclasses + @classmethod + def set_up_dependency_container(cls): + return + + @classmethod + def tear_down_dependency_container(cls): + return + + def get_application_port(self) -> int: + return 8080 + + def get_application_extra_environment_variables(self) -> Dict[str, str]: + return {} + + def get_application_network_aliases(self) -> List[str]: + return [] + + @staticmethod + def get_application_image_name() -> str: + return None + + def get_application_wait_pattern(self) -> str: + return "Ready" + + def get_application_otel_service_name(self) -> str: + return self.get_application_image_name() + + def get_application_otel_resource_attributes(self) -> str: + return "service.name=" + self.get_application_otel_service_name() + + def _assert_aws_span_attributes(self, resource_scope_spans: List[ResourceScopeSpan], path: str, **kwargs): + 
self.fail("Tests must implement this function") + + def _assert_semantic_conventions_span_attributes( + self, resource_scope_spans: List[ResourceScopeSpan], method: str, path: str, status_code: int, **kwargs + ): + self.fail("Tests must implement this function") + + def _assert_metric_attributes( + self, resource_scope_metrics: List[ResourceScopeMetric], metric_name: str, expected_sum: int, **kwargs + ): + self.fail("Tests must implement this function") diff --git a/contract-tests/tests/test/amazon/base/database_contract_test_base.py b/contract-tests/tests/test/amazon/base/database_contract_test_base.py new file mode 100644 index 00000000..6613f125 --- /dev/null +++ b/contract-tests/tests/test/amazon/base/database_contract_test_base.py @@ -0,0 +1,178 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from typing import Dict, List + +from mock_collector_client import ResourceScopeMetric, ResourceScopeSpan +from typing_extensions import override + +from amazon.base.contract_test_base import ContractTestBase +from amazon.utils.application_signals_constants import ( + AWS_LOCAL_OPERATION, + AWS_LOCAL_SERVICE, + AWS_REMOTE_DB_USER, + AWS_REMOTE_OPERATION, + AWS_REMOTE_RESOURCE_IDENTIFIER, + AWS_REMOTE_RESOURCE_TYPE, + AWS_REMOTE_SERVICE, + AWS_SPAN_KIND, +) +from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue +from opentelemetry.proto.metrics.v1.metrics_pb2 import ExponentialHistogramDataPoint, Metric +from opentelemetry.proto.trace.v1.trace_pb2 import Span +from opentelemetry.trace import StatusCode + +DATABASE_HOST: str = "mydb" +DATABASE_NAME: str = "testdb" +DATABASE_PASSWORD: str = "example" +DATABASE_USER: str = "root" +SPAN_KIND_CLIENT: str = "CLIENT" +SPAN_KIND_LOCAL_ROOT: str = "LOCAL_ROOT" + + +class DatabaseContractTestBase(ContractTestBase): + @staticmethod + def get_remote_service() -> str: + return None + + @staticmethod + def get_database_port() -> int: + return None + + def 
get_remote_resource_identifier(self) -> str: + return f"{DATABASE_NAME}|{DATABASE_HOST}|{self.get_database_port()}" + + @override + def get_application_extra_environment_variables(self) -> Dict[str, str]: + return { + "DB_HOST": DATABASE_HOST, + "DB_USER": DATABASE_USER, + "DB_PASS": DATABASE_PASSWORD, + "DB_NAME": DATABASE_NAME, + } + + # define tests for SQL database + def assert_drop_table_succeeds(self) -> None: + self.mock_collector_client.clear_signals() + self.do_test_requests("drop_table", "GET", 200, 0, 0, sql_command="DROP TABLE", local_operation="GET /drop_table", span_name="DROP") + + def assert_create_database_succeeds(self) -> None: + self.mock_collector_client.clear_signals() + self.do_test_requests("create_database", "GET", 200, 0, 0, sql_command="CREATE DATABASE", local_operation="GET /create_database", span_name="CREATE") + + def assert_select_succeeds(self) -> None: + self.mock_collector_client.clear_signals() + self.do_test_requests("select", "GET", 200, 0, 0, sql_command="SELECT", local_operation="GET /select", span_name="SELECT") + + def assert_fault(self) -> None: + self.mock_collector_client.clear_signals() + self.do_test_requests("fault", "GET", 500, 0, 1, sql_command="SELECT DISTINCT", local_operation="GET /fault", span_name="SELECT") + + # define tests for MongoDB database + def assert_delete_document_succeeds(self, **kwargs) -> None: + self.mock_collector_client.clear_signals() + self.do_test_requests("delete_document", "GET", 200, 0, 0, **kwargs) + + def assert_insert_document_succeeds(self, **kwargs) -> None: + self.mock_collector_client.clear_signals() + self.do_test_requests("insert_document", "GET", 200, 0, 0, **kwargs) + + def assert_update_document_succeeds(self, **kwargs) -> None: + self.mock_collector_client.clear_signals() + self.do_test_requests("update_document", "GET", 200, 0, 0, **kwargs) + + def assert_find_document_succeeds(self, **kwargs) -> None: + self.mock_collector_client.clear_signals() + 
self.do_test_requests("find", "GET", 200, 0, 0, **kwargs) + + def assert_fault_non_sql(self, **kwargs) -> None: + self.mock_collector_client.clear_signals() + self.do_test_requests("fault", "GET", 500, 0, 1, **kwargs) + + @override + def _assert_aws_span_attributes(self, resource_scope_spans: List[ResourceScopeSpan], path: str, **kwargs) -> None: + target_spans: List[Span] = [] + for resource_scope_span in resource_scope_spans: + # pylint: disable=no-member + if resource_scope_span.span.kind == Span.SPAN_KIND_CLIENT: + target_spans.append(resource_scope_span.span) + + self.assertEqual( + len(target_spans), 1, f"target_spans is {str(target_spans)}, although only one walue was expected" + ) + self._assert_aws_attributes(target_spans[0].attributes, **kwargs) + + @override + def _assert_semantic_conventions_span_attributes( + self, resource_scope_spans: List[ResourceScopeSpan], method: str, path: str, status_code: int, **kwargs + ) -> None: + target_spans: List[Span] = [] + for resource_scope_span in resource_scope_spans: + # pylint: disable=no-member + if resource_scope_span.span.kind == Span.SPAN_KIND_CLIENT: + target_spans.append(resource_scope_span.span) + + self.assertEqual(target_spans[0].name, kwargs.get("span_name")) + if status_code == 200: + self.assertEqual(target_spans[0].status.code, StatusCode.UNSET.value) + else: + self.assertEqual(target_spans[0].status.code, StatusCode.ERROR.value) + + self._assert_semantic_conventions_attributes(target_spans[0].attributes, **kwargs) + + def _assert_semantic_conventions_attributes(self, attributes_list: List[KeyValue], **kwargs) -> None: + attributes_dict: Dict[str, AnyValue] = self._get_attributes_dict(attributes_list) + command = kwargs.get("db_operation") or kwargs.get("sql_command") + self.assertTrue(attributes_dict.get("db.statement").string_value.startswith(command)) + self._assert_str_attribute(attributes_dict, "db.system", self.get_remote_service()) + self._assert_str_attribute(attributes_dict, "db.name", 
DATABASE_NAME) + self._assert_str_attribute(attributes_dict, "net.peer.name", DATABASE_HOST) + self._assert_int_attribute(attributes_dict, "net.peer.port", self.get_database_port()) + self.assertTrue("server.address" not in attributes_dict) + self.assertTrue("server.port" not in attributes_dict) + self.assertTrue("db.operation" not in attributes_dict) + + @override + def _assert_aws_attributes( + self, attributes_list: List[KeyValue], expected_span_kind: str = SPAN_KIND_CLIENT, **kwargs + ) -> None: + attributes_dict: Dict[str, AnyValue] = self._get_attributes_dict(attributes_list) + self._assert_str_attribute(attributes_dict, AWS_LOCAL_SERVICE, self.get_application_otel_service_name()) + self._assert_str_attribute(attributes_dict, AWS_LOCAL_OPERATION, kwargs.get("local_operation")) + self._assert_str_attribute(attributes_dict, AWS_REMOTE_SERVICE, self.get_remote_service()) + self._assert_str_attribute(attributes_dict, AWS_REMOTE_OPERATION, kwargs.get("sql_command")) + self._assert_str_attribute(attributes_dict, AWS_REMOTE_RESOURCE_TYPE, "DB::Connection") + self._assert_str_attribute(attributes_dict, AWS_REMOTE_DB_USER, DATABASE_USER) + self._assert_str_attribute( + attributes_dict, AWS_REMOTE_RESOURCE_IDENTIFIER, self.get_remote_resource_identifier() + ) + self._assert_str_attribute(attributes_dict, AWS_SPAN_KIND, expected_span_kind) + + @override + def _assert_metric_attributes( + self, resource_scope_metrics: List[ResourceScopeMetric], metric_name: str, expected_sum: int, **kwargs + ) -> None: + target_metrics: List[Metric] = [] + for resource_scope_metric in resource_scope_metrics: + if resource_scope_metric.metric.name.lower() == metric_name.lower(): + target_metrics.append(resource_scope_metric.metric) + self.assertLessEqual( + len(target_metrics), + 2, + f"target_metrics is {str(target_metrics)}, although we expect less than or equal to 2 metrics", + ) + dp_list: List[ExponentialHistogramDataPoint] = [ + dp for target_metric in target_metrics for dp in 
target_metric.exponential_histogram.data_points + ] + self.assertEqual(len(dp_list), 2) + dependency_dp: ExponentialHistogramDataPoint = dp_list[0] + service_dp: ExponentialHistogramDataPoint = dp_list[1] + if len(dp_list[1].attributes) > len(dp_list[0].attributes): + dependency_dp = dp_list[1] + service_dp = dp_list[0] + self._assert_aws_attributes(dependency_dp.attributes, SPAN_KIND_CLIENT, **kwargs) + self.check_sum(metric_name, dependency_dp.sum, expected_sum) + + attribute_dict: Dict[str, AnyValue] = self._get_attributes_dict(service_dp.attributes) + self._assert_str_attribute(attribute_dict, AWS_LOCAL_OPERATION,kwargs.get("local_operation")) + self._assert_str_attribute(attribute_dict, AWS_SPAN_KIND, SPAN_KIND_LOCAL_ROOT) + self.check_sum(metric_name, service_dp.sum, expected_sum) diff --git a/contract-tests/tests/test/amazon/http/http_test.py b/contract-tests/tests/test/amazon/http/http_test.py new file mode 100644 index 00000000..372bc2ae --- /dev/null +++ b/contract-tests/tests/test/amazon/http/http_test.py @@ -0,0 +1,163 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from typing import Dict, List +from logging import INFO, Logger, getLogger + +from mock_collector_client import ResourceScopeMetric, ResourceScopeSpan +from typing_extensions import override + +from amazon.base.contract_test_base import ContractTestBase +from amazon.utils.application_signals_constants import ( + AWS_LOCAL_OPERATION, + AWS_LOCAL_SERVICE, + AWS_REMOTE_OPERATION, + AWS_REMOTE_SERVICE, + AWS_SPAN_KIND, +) +from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue +from opentelemetry.proto.metrics.v1.metrics_pb2 import ExponentialHistogramDataPoint, Metric +from opentelemetry.proto.trace.v1.trace_pb2 import Span +from opentelemetry.semconv.trace import SpanAttributes + +_logger: Logger = getLogger(__name__) +_logger.setLevel(INFO) + +class HttpTest(ContractTestBase): + @override + @staticmethod + def get_application_image_name() -> str: + return "aws-application-signals-tests-http-app" + + @override + def get_application_network_aliases(self) -> List[str]: + """ + This will be the target hostname of the clients making http requests in the application image, so that they + don't use localhost. + """ + return ["backend"] + + @override + def get_application_extra_environment_variables(self) -> Dict[str, str]: + """ + This does not appear to do anything, as it does not seem that OTEL supports peer service for Python. Keeping + for consistency with Java contract tests at this time. 
+ """ + return {"OTEL_INSTRUMENTATION_COMMON_PEER_SERVICE_MAPPING": "backend=backend:8080"} + + def test_success(self) -> None: + self.do_test_requests("success", "GET", 200, 0, 0, request_method="GET", path_suffix="/success") + + def test_error(self) -> None: + self.do_test_requests("error", "GET", 400, 1, 0, request_method="GET", path_suffix="/error") + + def test_fault(self) -> None: + self.do_test_requests("fault", "GET", 500, 0, 1, request_method="GET", path_suffix="/fault") + + def test_success_post(self) -> None: + self.do_test_requests("success/postmethod", "POST", 200, 0, 0, request_method="POST", path_suffix="/success") + + def test_error_post(self) -> None: + self.do_test_requests("error/postmethod", "POST", 400, 1, 0, request_method="POST", path_suffix="/error") + + def test_fault_post(self) -> None: + self.do_test_requests("fault/postmethod", "POST", 500, 0, 1, request_method="POST", path_suffix="/fault") + + @override + def _assert_aws_span_attributes(self, resource_scope_spans: List[ResourceScopeSpan], path: str, **kwargs) -> None: + target_spans: List[Span] = [] + for resource_scope_span in resource_scope_spans: + # pylint: disable=no-member + if resource_scope_span.span.kind == Span.SPAN_KIND_CLIENT: + target_spans.append(resource_scope_span.span) + + self.assertEqual(len(target_spans), 1) + # _logger.info(target_spans[0].attributes) + self._assert_aws_attributes(target_spans[0].attributes, kwargs.get("request_method"), kwargs.get("path_suffix")) + + def _assert_aws_attributes(self, attributes_list: List[KeyValue], method: str, endpoint: str) -> None: + attributes_dict: Dict[str, AnyValue] = self._get_attributes_dict(attributes_list) + self._assert_str_attribute(attributes_dict, AWS_LOCAL_SERVICE, self.get_application_otel_service_name()) + self._assert_str_attribute(attributes_dict, AWS_LOCAL_OPERATION, f"{method} {endpoint}") + self._assert_str_attribute(attributes_dict, AWS_REMOTE_SERVICE, "backend:8080") + 
self._assert_str_attribute(attributes_dict, AWS_REMOTE_OPERATION, f"{method} /backend") + self._assert_str_attribute(attributes_dict, AWS_SPAN_KIND, "CLIENT") + + @override + def _assert_semantic_conventions_span_attributes( + self, resource_scope_spans: List[ResourceScopeSpan], method: str, path: str, status_code: int, **kwargs + ) -> None: + target_spans: List[Span] = [] + for resource_scope_span in resource_scope_spans: + # pylint: disable=no-member + if resource_scope_span.span.kind == Span.SPAN_KIND_CLIENT: + target_spans.append(resource_scope_span.span) + + self.assertEqual(len(target_spans), 1) + self.assertEqual(target_spans[0].name, method) + self._assert_semantic_conventions_attributes(target_spans[0].attributes, method, path, status_code) + + def _assert_semantic_conventions_attributes( + self, attributes_list: List[KeyValue], method: str, endpoint: str, status_code: int + ) -> None: + attributes_dict: Dict[str, AnyValue] = self._get_attributes_dict(attributes_list) + self._assert_str_attribute(attributes_dict, SpanAttributes.NET_PEER_NAME, "backend") + self._assert_int_attribute(attributes_dict, SpanAttributes.NET_PEER_PORT, 8080) + self._assert_int_attribute(attributes_dict, SpanAttributes.HTTP_STATUS_CODE, status_code) + self._assert_str_attribute(attributes_dict, SpanAttributes.HTTP_URL, f"http://backend:8080/backend/{endpoint}") + self._assert_str_attribute(attributes_dict, SpanAttributes.HTTP_METHOD, method) + # http instrumentation is not respecting PEER_SERVICE + # self._assert_str_attribute(attributes_dict, SpanAttributes.PEER_SERVICE, "backend:8080") + + @override + def _assert_metric_attributes( + self, + resource_scope_metrics: List[ResourceScopeMetric], + metric_name: str, + expected_sum: int, + **kwargs, + ) -> None: + target_metrics: List[Metric] = [] + for resource_scope_metric in resource_scope_metrics: + if resource_scope_metric.metric.name.lower() == metric_name.lower(): + target_metrics.append(resource_scope_metric.metric) + + 
self.assertEqual(len(target_metrics), 1) + target_metric: Metric = target_metrics[0] + dp_list: List[ExponentialHistogramDataPoint] = target_metric.exponential_histogram.data_points + + + self.assertEqual(len(dp_list), 3) + + # Find the data point with the longest attributes list and assign it to dependency_dp + dependency_dp: ExponentialHistogramDataPoint = max(dp_list, key=lambda dp: len(dp.attributes)) + # Assign the remaining two elements to service_dp and other_dp + remaining_dps = [dp for dp in dp_list if dp != dependency_dp] + service_dp, other_dp = remaining_dps[0], remaining_dps[1] + + attribute_dict: Dict[str, AnyValue] = self._get_attributes_dict(dependency_dp.attributes) + method: str = kwargs.get("request_method") + path_suffix: str = kwargs.get("path_suffix") + self._assert_str_attribute(attribute_dict, AWS_LOCAL_SERVICE, self.get_application_otel_service_name()) + self._assert_str_attribute(attribute_dict, AWS_LOCAL_OPERATION, f"{method} {path_suffix}") + self._assert_str_attribute(attribute_dict, AWS_REMOTE_SERVICE, "backend:8080") + self._assert_str_attribute(attribute_dict, AWS_REMOTE_OPERATION, f"{method} /backend") + self._assert_str_attribute(attribute_dict, AWS_SPAN_KIND, "CLIENT") + self.check_sum(metric_name, dependency_dp.sum, expected_sum) + + attribute_dict_service: Dict[str, AnyValue] = self._get_attributes_dict(service_dp.attributes) + attribute_dict_other: Dict[str, AnyValue] = self._get_attributes_dict(other_dp.attributes) + + # test AWS_LOCAL_OPERATION to be either "/backend" or "/success" in service_dp and other_dp + if f"{method} /backend" not in [attribute_dict_service.get(AWS_LOCAL_OPERATION), attribute_dict_other.get(AWS_LOCAL_OPERATION)]: + self._assert_str_attribute(attribute_dict_service, AWS_LOCAL_OPERATION, f"{method} /backend") + self._assert_str_attribute(attribute_dict_other, AWS_LOCAL_OPERATION, f"{method} {path_suffix}") + else: + self._assert_str_attribute(attribute_dict_service, AWS_LOCAL_OPERATION, f"{method} 
{path_suffix}") + self._assert_str_attribute(attribute_dict_other, AWS_LOCAL_OPERATION, f"{method} /backend") + + # Check additional attributes for service_dp + self._assert_str_attribute(attribute_dict_service, AWS_SPAN_KIND, "LOCAL_ROOT") + self.check_sum(metric_name, service_dp.sum, expected_sum) + + self._assert_str_attribute(attribute_dict_other, AWS_SPAN_KIND, "LOCAL_ROOT") + self.check_sum(metric_name, other_dp.sum, expected_sum) diff --git a/contract-tests/tests/test/amazon/misc/configuration_test.py b/contract-tests/tests/test/amazon/misc/configuration_test.py new file mode 100644 index 00000000..449b3e73 --- /dev/null +++ b/contract-tests/tests/test/amazon/misc/configuration_test.py @@ -0,0 +1,88 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import time +from typing import Dict, List + +from mock_collector_client import ResourceScopeMetric, ResourceScopeSpan +from requests import Response, request +from typing_extensions import override + +from amazon.base.contract_test_base import ContractTestBase +from amazon.utils.application_signals_constants import ERROR_METRIC, FAULT_METRIC, LATENCY_METRIC +from opentelemetry.sdk.metrics.export import AggregationTemporality + +# Tests in this class are supposed to validate that the SDK was configured in the correct way: It +# uses the X-Ray ID format. Metrics are deltaPreferred. Type of the metrics are exponentialHistogram + + +class ConfigurationTest(ContractTestBase): + @override + @staticmethod + def get_application_image_name() -> str: + return "aws-application-signals-tests-http-app" + + @override + def get_application_network_aliases(self) -> List[str]: + """ + This will be the target hostname of the clients making http requests in the application image, so that they + don't use localhost. 
+ """ + return ["backend"] + + def test_configuration_metrics(self): + address: str = self.application.get_container_host_ip() + port: str = self.application.get_exposed_port(self.get_application_port()) + url: str = f"http://{address}:{port}/success" + response: Response = request("GET", url, timeout=20) + self.assertEqual(200, response.status_code) + metrics: List[ResourceScopeMetric] = self.mock_collector_client.get_metrics( + {LATENCY_METRIC, ERROR_METRIC, FAULT_METRIC} + ) + + self.assertEqual(len(metrics), 3) + for metric in metrics: + self.assertIsNotNone(metric.metric.exponential_histogram) + self.assertEqual(metric.metric.exponential_histogram.aggregation_temporality, AggregationTemporality.DELTA) + self.mock_collector_client.clear_signals() + + def test_xray_id_format(self): + """ + We are testing here that the X-Ray id format is always used by inspecting the traceid that + was in the span received by the collector, which should be consistent across multiple spans. + We are testing the following properties: + 1. Traceid is random + 2. First 32 bits of traceid is a timestamp + It is important to remember that the X-Ray traceId format had to be adapted to fit into the + definition of the OpenTelemetry traceid: + https://opentelemetry.io/docs/specs/otel/trace/api/#retrieving-the-traceid-and-spanid + Specifically for an X-Ray traceid to be a valid Otel traceId, the version digit had to be + dropped. 
Reference: + https://github.com/open-telemetry/opentelemetry-python-contrib/blob/main/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/aws_xray_id_generator.py + """ + + seen: List[str] = [] + for _ in range(100): + address: str = self.application.get_container_host_ip() + port: str = self.application.get_exposed_port(self.get_application_port()) + url: str = f"http://{address}:{port}/success" + response: Response = request("GET", url, timeout=20) + self.assertEqual(200, response.status_code) + + # Since we just made the request, the time in epoch registered in the traceid should be + # approximately equal to the current time in the test, since both run on the same host. + start_time_sec: int = int(time.time()) + + resource_scope_spans: List[ResourceScopeSpan] = self.mock_collector_client.get_traces() + target_span: ResourceScopeSpan = resource_scope_spans[0] + + self.assertTrue(target_span.span.trace_id.hex() not in seen) + seen.append(target_span.span.trace_id.hex()) + + # trace_id is bytes, so we convert it to hex string and pick the first 8 hex characters (4 bytes) + # that represent the timestamp, then convert them to an int for the timestamp in seconds + trace_id_time_stamp_int: int = int(target_span.span.trace_id.hex()[:8], 16) + + # Give 2 minutes time range of tolerance for the trace timestamp + self.assertGreater(trace_id_time_stamp_int, start_time_sec - 60) + self.assertGreater(start_time_sec + 60, trace_id_time_stamp_int) + self.mock_collector_client.clear_signals() diff --git a/contract-tests/tests/test/amazon/misc/resource_attributes_test_base.py b/contract-tests/tests/test/amazon/misc/resource_attributes_test_base.py new file mode 100644 index 00000000..66e5d36c --- /dev/null +++ b/contract-tests/tests/test/amazon/misc/resource_attributes_test_base.py @@ -0,0 +1,79 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from typing import Dict, List + +from mock_collector_client import ResourceScopeMetric, ResourceScopeSpan +from requests import Response, request +from typing_extensions import override + +from amazon.base.contract_test_base import ContractTestBase +from amazon.utils.application_signals_constants import ERROR_METRIC, FAULT_METRIC, LATENCY_METRIC +from opentelemetry.proto.common.v1.common_pb2 import AnyValue +from opentelemetry.proto.metrics.v1.metrics_pb2 import Metric +from opentelemetry.proto.trace.v1.trace_pb2 import Span + + +def _get_k8s_attributes(): + return { + "k8s.namespace.name": "namespace-name", + "k8s.pod.name": "pod-name", + "k8s.deployment.name": "deployment-name", + } + + +# Tests consuming this class are supposed to validate that the agent is able to get the resource +# attributes through the environment variables OTEL_RESOURCE_ATTRIBUTES and OTEL_SERVICE_NAME +# +# These tests are structured with nested classes since it is only possible to change the +# resource attributes during the initialization of the OpenTelemetry SDK. + + +class ResourceAttributesTest(ContractTestBase): + @override + @staticmethod + def get_application_image_name() -> str: + return "aws-application-signals-tests-http-app" + + @override + def get_application_network_aliases(self) -> List[str]: + """ + This will be the target hostname of the clients making http requests in the application image, so that they + don't use localhost. 
+ """ + return ["backend"] + + def do_test_resource_attributes(self, service_name): + address: str = self.application.get_container_host_ip() + port: str = self.application.get_exposed_port(self.get_application_port()) + url: str = f"http://{address}:{port}/success" + response: Response = request("GET", url, timeout=20) + self.assertEqual(200, response.status_code) + self.assert_resource_attributes(service_name) + + def assert_resource_attributes(self, service_name): + resource_scope_spans: List[ResourceScopeSpan] = self.mock_collector_client.get_traces() + metrics: List[ResourceScopeMetric] = self.mock_collector_client.get_metrics( + {LATENCY_METRIC, ERROR_METRIC, FAULT_METRIC} + ) + target_spans: List[Span] = [] + for resource_scope_span in resource_scope_spans: + # pylint: disable=no-member + if resource_scope_span.span.name == "tcp.connect": + target_spans.append(resource_scope_span.resource_spans) + + self.assertEqual(len(target_spans), 1) + attributes_dict: Dict[str, AnyValue] = self._get_attributes_dict(target_spans[0].resource.attributes) + for key, value in _get_k8s_attributes().items(): + self._assert_str_attribute(attributes_dict, key, value) + self._assert_str_attribute(attributes_dict, "service.name", service_name) + + target_metrics: List[Metric] = [] + for resource_scope_metric in metrics: + if resource_scope_metric.metric.name in ["Error", "Fault", "Latency"]: + target_metrics.append(resource_scope_metric.resource_metrics) + self.assertEqual(len(target_metrics), 3) + for target_metric in target_metrics: + metric_attributes_dict: Dict[str, AnyValue] = self._get_attributes_dict(target_metric.resource.attributes) + for key, value in _get_k8s_attributes().items(): + self._assert_str_attribute(metric_attributes_dict, key, value) + self._assert_str_attribute(metric_attributes_dict, "service.name", service_name) diff --git a/contract-tests/tests/test/amazon/misc/service_name_in_env_var_test.py 
b/contract-tests/tests/test/amazon/misc/service_name_in_env_var_test.py new file mode 100644 index 00000000..719855b3 --- /dev/null +++ b/contract-tests/tests/test/amazon/misc/service_name_in_env_var_test.py @@ -0,0 +1,24 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from typing import List + +from resource_attributes_test_base import ResourceAttributesTest, _get_k8s_attributes +from typing_extensions import override + + +class ServiceNameInEnvVarTest(ResourceAttributesTest): + @override + # pylint: disable=no-self-use + def get_application_extra_environment_variables(self) -> str: + return {"OTEL_SERVICE_NAME": "service-name-test"} + + @override + # pylint: disable=no-self-use + def get_application_otel_resource_attributes(self) -> str: + pairlist: List[str] = [] + for key, value in _get_k8s_attributes().items(): + pairlist.append(key + "=" + value) + return ",".join(pairlist) + + def test_service(self) -> None: + self.do_test_resource_attributes("service-name-test") diff --git a/contract-tests/tests/test/amazon/misc/service_name_in_resource_attributes_test.py b/contract-tests/tests/test/amazon/misc/service_name_in_resource_attributes_test.py new file mode 100644 index 00000000..923c2cfa --- /dev/null +++ b/contract-tests/tests/test/amazon/misc/service_name_in_resource_attributes_test.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from typing import List + +from resource_attributes_test_base import ResourceAttributesTest, _get_k8s_attributes +from typing_extensions import override + + +class ServiceNameInResourceAttributesTest(ResourceAttributesTest): + @override + # pylint: disable=no-self-use + def get_application_otel_resource_attributes(self) -> str: + pairlist: List[str] = [] + for key, value in _get_k8s_attributes().items(): + pairlist.append(key + "=" + value) + pairlist.append("service.name=service-name") + return ",".join(pairlist) + + def test_service(self) -> None: + self.do_test_resource_attributes("service-name") diff --git a/contract-tests/tests/test/amazon/misc/unknown_service_name_test.py b/contract-tests/tests/test/amazon/misc/unknown_service_name_test.py new file mode 100644 index 00000000..a5e4be1a --- /dev/null +++ b/contract-tests/tests/test/amazon/misc/unknown_service_name_test.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from typing import List + +from resource_attributes_test_base import ResourceAttributesTest, _get_k8s_attributes +from typing_extensions import override + + +class UnknownServiceNameTest(ResourceAttributesTest): + @override + # pylint: disable=no-self-use + def get_application_otel_resource_attributes(self) -> str: + pairlist: List[str] = [] + for key, value in _get_k8s_attributes().items(): + pairlist.append(key + "=" + value) + return ",".join(pairlist) + + def test_service(self) -> None: + # See https://github.com/aws-observability/aws-otel-js-instrumentation/blob/cec7306366a29ebb87cd303cb820abfe50cd5e30/aws-distro-opentelemetry-node-autoinstrumentation/src/aws-metric-attribute-generator.ts#L62-L66 + self.do_test_resource_attributes("unknown_service:node") diff --git a/contract-tests/tests/test/amazon/mongodb/mongodb_test.py b/contract-tests/tests/test/amazon/mongodb/mongodb_test.py new file mode 100644 index 00000000..d6164f7a --- /dev/null +++ b/contract-tests/tests/test/amazon/mongodb/mongodb_test.py @@ -0,0 +1,112 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from typing import Dict, List + +# from testcontainers.mysql import MySqlContainer +# from testcontainers.mongodb import MongoDbContainer +from testcontainers.core.container import DockerContainer +from typing_extensions import override + +from amazon.base.contract_test_base import NETWORK_NAME +from amazon.base.database_contract_test_base import ( + DATABASE_HOST, + DATABASE_PASSWORD, + DATABASE_USER, + SPAN_KIND_CLIENT, + DatabaseContractTestBase, +) +from amazon.utils.application_signals_constants import ( + AWS_LOCAL_OPERATION, + AWS_LOCAL_SERVICE, + AWS_REMOTE_OPERATION, + AWS_REMOTE_RESOURCE_IDENTIFIER, + AWS_REMOTE_RESOURCE_TYPE, + AWS_REMOTE_SERVICE, + AWS_SPAN_KIND, +) +from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue + + +class MongodbTest(DatabaseContractTestBase): + @override + @classmethod + def set_up_dependency_container(cls) -> None: + cls.container = ( + DockerContainer("mongo:7.0.9") + .with_env("MONGO_INITDB_ROOT_USERNAME", DATABASE_USER) + .with_env("MONGO_INITDB_ROOT_PASSWORD", DATABASE_PASSWORD) + .with_kwargs(network=NETWORK_NAME) + .with_name(DATABASE_HOST) + ) + cls.container.start() + + @override + @classmethod + def tear_down_dependency_container(cls) -> None: + cls.container.stop() + + @override + @staticmethod + def get_remote_service() -> str: + return "mongodb" + + @override + @staticmethod + def get_database_port() -> int: + return 27017 + + @override + @staticmethod + def get_application_image_name() -> str: + return "aws-application-signals-tests-mongodb-app" + + def test_find_document_succeeds(self) -> None: + self.assert_find_document_succeeds(local_operation='GET /find', span_name='mongodb.find', db_operation='find', db_statement='statement') + + def test_delete_document_succeeds(self) -> None: + self.assert_delete_document_succeeds(local_operation='GET /delete_document', span_name='mongodb.delete', db_operation='delete') + + def test_insert_document_succeeds(self) 
-> None: + self.assert_insert_document_succeeds(local_operation='GET /insert_document', span_name='mongodb.insert', db_operation='insert') + + def test_update_document_succeeds(self) -> None: + # We don't know why "db.mongodb.collection" is set to "$cmd". It's probably a bug in upstream. + self.assert_update_document_succeeds(local_operation='GET /update_document', span_name='mongodb.findAndModify', db_operation='findAndModify', mongodb_collection='$cmd') + + + def test_fault(self) -> None: + # We don't know why "db.mongodb.collection" is set to "$cmd". It's probably a bug in upstream. + self.assert_fault_non_sql(local_operation='GET /fault', span_name='mongodb.invalidCommand', db_operation='invalidCommand', mongodb_collection='$cmd') + + @override + def _assert_aws_attributes( + self, attributes_list: List[KeyValue], expected_span_kind: str = SPAN_KIND_CLIENT, **kwargs + ) -> None: + attributes_dict: Dict[str, AnyValue] = self._get_attributes_dict(attributes_list) + self._assert_str_attribute(attributes_dict, AWS_LOCAL_SERVICE, self.get_application_otel_service_name()) + self._assert_str_attribute(attributes_dict, AWS_LOCAL_OPERATION, kwargs.get("local_operation")) + self._assert_str_attribute(attributes_dict, AWS_REMOTE_SERVICE, self.get_remote_service()) + self._assert_str_attribute(attributes_dict, AWS_REMOTE_OPERATION, kwargs.get("db_operation")) + self._assert_str_attribute(attributes_dict, AWS_REMOTE_RESOURCE_TYPE, "DB::Connection") + # We might need to revisit the assertion here + # Currently the value is 'testdb|172.31.0.3|27017' not the expected one 'testdb|mydb|27017' + # self._assert_str_attribute( + # attributes_dict, AWS_REMOTE_RESOURCE_IDENTIFIER, self.get_remote_resource_identifier() + # ) + self._assert_str_attribute(attributes_dict, AWS_SPAN_KIND, expected_span_kind) + + @override + def _assert_semantic_conventions_attributes(self, attributes_list: List[KeyValue], **kwargs) -> None: + attributes_dict: Dict[str, AnyValue] = 
self._get_attributes_dict(attributes_list) + self._assert_str_attribute(attributes_dict, "db.mongodb.collection", kwargs.get("mongodb_collection") or "employees") + self._assert_str_attribute(attributes_dict, "db.system", self.get_remote_service()) + self._assert_str_attribute(attributes_dict, "db.name", "testdb") + # the net.peer.name is currently set to be an ip address like '192.168.208.3' + # self._assert_str_attribute(attributes_dict, "net.peer.name", "mydb") + self.assertTrue("net.peer.name" in attributes_dict) #just checking the existence + self._assert_int_attribute(attributes_dict, "net.peer.port", self.get_database_port()) + self._assert_str_attribute(attributes_dict, "db.operation", kwargs.get("db_operation")) + self.assertTrue("db.statement" in attributes_dict) #just checking the existence + self.assertTrue("db.user" not in attributes_dict) + self.assertTrue("server.address" not in attributes_dict) + self.assertTrue("server.port" not in attributes_dict) diff --git a/contract-tests/tests/test/amazon/mongoose/mongoose_test.py b/contract-tests/tests/test/amazon/mongoose/mongoose_test.py new file mode 100644 index 00000000..7b284fdf --- /dev/null +++ b/contract-tests/tests/test/amazon/mongoose/mongoose_test.py @@ -0,0 +1,103 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from typing import Dict, List + +# from testcontainers.mysql import MySqlContainer +# from testcontainers.mongodb import MongoDbContainer +from testcontainers.core.container import DockerContainer +from typing_extensions import override + +from amazon.base.contract_test_base import NETWORK_NAME +from amazon.base.database_contract_test_base import ( + DATABASE_HOST, + DATABASE_PASSWORD, + DATABASE_USER, + SPAN_KIND_CLIENT, + DatabaseContractTestBase, +) +from amazon.utils.application_signals_constants import ( + AWS_LOCAL_OPERATION, + AWS_LOCAL_SERVICE, + AWS_REMOTE_OPERATION, + AWS_REMOTE_RESOURCE_IDENTIFIER, + AWS_REMOTE_RESOURCE_TYPE, + AWS_REMOTE_SERVICE, + AWS_SPAN_KIND, +) +from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue + + +class MongooseTest(DatabaseContractTestBase): + @override + @classmethod + def set_up_dependency_container(cls) -> None: + cls.container = ( + DockerContainer("mongo:7.0.9") + .with_env("MONGO_INITDB_ROOT_USERNAME", DATABASE_USER) + .with_env("MONGO_INITDB_ROOT_PASSWORD", DATABASE_PASSWORD) + .with_kwargs(network=NETWORK_NAME) + .with_name(DATABASE_HOST) + ) + cls.container.start() + + @override + @classmethod + def tear_down_dependency_container(cls) -> None: + cls.container.stop() + + @override + @staticmethod + def get_remote_service() -> str: + return "mongoose" + + @override + @staticmethod + def get_database_port() -> int: + return 27017 + + @override + @staticmethod + def get_application_image_name() -> str: + return "aws-application-signals-tests-mongoose-app" + + def test_find_document_succeeds(self) -> None: + self.assert_find_document_succeeds(local_operation='GET /find', span_name='mongoose.Employee.find', db_operation='find') + + def test_delete_document_succeeds(self) -> None: + self.assert_delete_document_succeeds(local_operation='GET /delete_document', span_name='mongoose.Employee.deleteOne', db_operation='deleteOne') + + def 
test_insert_document_succeeds(self) -> None: + self.assert_insert_document_succeeds(local_operation='GET /insert_document', span_name='mongoose.Employee.save', db_operation='save') + + def test_update_document_succeeds(self) -> None: + self.assert_update_document_succeeds(local_operation='GET /update_document', span_name='mongoose.Employee.findOneAndUpdate', db_operation='findOneAndUpdate') + + + @override + def _assert_aws_attributes( + self, attributes_list: List[KeyValue], expected_span_kind: str = SPAN_KIND_CLIENT, **kwargs + ) -> None: + attributes_dict: Dict[str, AnyValue] = self._get_attributes_dict(attributes_list) + self._assert_str_attribute(attributes_dict, AWS_LOCAL_SERVICE, self.get_application_otel_service_name()) + self._assert_str_attribute(attributes_dict, AWS_LOCAL_OPERATION, kwargs.get("local_operation")) + self._assert_str_attribute(attributes_dict, AWS_REMOTE_SERVICE, self.get_remote_service()) + self._assert_str_attribute(attributes_dict, AWS_REMOTE_OPERATION, kwargs.get("db_operation")) + self._assert_str_attribute(attributes_dict, AWS_REMOTE_RESOURCE_TYPE, "DB::Connection") + self._assert_str_attribute( + attributes_dict, AWS_REMOTE_RESOURCE_IDENTIFIER, self.get_remote_resource_identifier() + ) + self._assert_str_attribute(attributes_dict, AWS_SPAN_KIND, expected_span_kind) + + @override + def _assert_semantic_conventions_attributes(self, attributes_list: List[KeyValue], **kwargs) -> None: + attributes_dict: Dict[str, AnyValue] = self._get_attributes_dict(attributes_list) + self._assert_str_attribute(attributes_dict, "db.mongodb.collection", "employees") + self._assert_str_attribute(attributes_dict, "db.system", self.get_remote_service()) + self._assert_str_attribute(attributes_dict, "db.name", "testdb") + self._assert_str_attribute(attributes_dict, "net.peer.name", "mydb") + self._assert_int_attribute(attributes_dict, "net.peer.port", self.get_database_port()) + self._assert_str_attribute(attributes_dict, "db.operation", 
kwargs.get("db_operation")) + self.assertTrue("db.statement" not in attributes_dict) + self.assertTrue("db.user" not in attributes_dict) + self.assertTrue("server.address" not in attributes_dict) + self.assertTrue("server.port" not in attributes_dict) diff --git a/contract-tests/tests/test/amazon/mysql2/mysql2_test.py b/contract-tests/tests/test/amazon/mysql2/mysql2_test.py new file mode 100644 index 00000000..a6b892bb --- /dev/null +++ b/contract-tests/tests/test/amazon/mysql2/mysql2_test.py @@ -0,0 +1,60 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from typing import Dict, List + +from testcontainers.mysql import MySqlContainer +from typing_extensions import override + +from amazon.base.contract_test_base import NETWORK_NAME +from amazon.base.database_contract_test_base import ( + DATABASE_HOST, + DATABASE_NAME, + DATABASE_PASSWORD, + DATABASE_USER, + DatabaseContractTestBase, +) + + +class Mysql2Test(DatabaseContractTestBase): + @override + @classmethod + def set_up_dependency_container(cls) -> None: + cls.container = ( + MySqlContainer(MYSQL_USER=DATABASE_USER, MYSQL_PASSWORD=DATABASE_PASSWORD, MYSQL_DATABASE=DATABASE_NAME) + .with_kwargs(network=NETWORK_NAME) + .with_name(DATABASE_HOST) + ) + cls.container.start() + + @override + @classmethod + def tear_down_dependency_container(cls) -> None: + cls.container.stop() + + @override + @staticmethod + def get_remote_service() -> str: + return "mysql" + + @override + @staticmethod + def get_database_port() -> int: + return 3306 + + @override + @staticmethod + def get_application_image_name() -> str: + return "aws-application-signals-tests-mysql2-app" + + def test_select_succeeds(self) -> None: + self.assert_select_succeeds() + + def test_drop_table_succeeds(self) -> None: + self.assert_drop_table_succeeds() + + def test_create_database_succeeds(self) -> None: + self.assert_create_database_succeeds() + + def test_fault(self) -> None: + 
self.assert_fault() + diff --git a/contract-tests/tests/test/amazon/utils/application_signals_constants.py b/contract-tests/tests/test/amazon/utils/application_signals_constants.py new file mode 100644 index 00000000..9f3a625a --- /dev/null +++ b/contract-tests/tests/test/amazon/utils/application_signals_constants.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Constants for attributes and metric names defined in Application Signals. +""" + +# Metric names +LATENCY_METRIC: str = "latency" +ERROR_METRIC: str = "error" +FAULT_METRIC: str = "fault" + +# Attribute names +AWS_LOCAL_SERVICE: str = "aws.local.service" +AWS_LOCAL_OPERATION: str = "aws.local.operation" +AWS_REMOTE_DB_USER: str = "aws.remote.db.user" +AWS_REMOTE_SERVICE: str = "aws.remote.service" +AWS_REMOTE_OPERATION: str = "aws.remote.operation" +AWS_REMOTE_RESOURCE_TYPE: str = "aws.remote.resource.type" +AWS_REMOTE_RESOURCE_IDENTIFIER: str = "aws.remote.resource.identifier" +AWS_SPAN_KIND: str = "aws.span.kind" diff --git a/scripts/build_and_install_distro.sh b/scripts/build_and_install_distro.sh index 5c2e8418..e5c474f0 100755 --- a/scripts/build_and_install_distro.sh +++ b/scripts/build_and_install_distro.sh @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # Fail fast -set -e +set -ex # Check script is running in scripts current_path=`pwd` @@ -17,3 +17,10 @@ fi npm install npm run compile +cd aws-distro-opentelemetry-node-autoinstrumentation +npm pack +cd .. + +mkdir -p dist +mv aws-distro-opentelemetry-node-autoinstrumentation/aws-aws-distro-opentelemetry-node-autoinstrumentation-*.tgz dist/ + diff --git a/scripts/set-up-contract-tests.sh b/scripts/set-up-contract-tests.sh new file mode 100755 index 00000000..3950371b --- /dev/null +++ b/scripts/set-up-contract-tests.sh @@ -0,0 +1,72 @@ +#!/bin/bash +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +# Fail fast +set -e + +# Check script is running in contract-tests +current_path=`pwd` +current_dir="${current_path##*/}" +if [ "$current_dir" != "aws-otel-js-instrumentation" ]; then + echo "Please run from aws-otel-js-instrumentation dir" + exit +fi + +# Remove old whl files (excluding distro whl) +rm -rf dist/mock_collector* +rm -rf dist/contract_tests* + +# Install python dependency for contract-test +python3 -m pip install pytest +python3 -m pip install pymysql +python3 -m pip install cryptography +python3 -m pip install mysql-connector-python +python3 -m pip install build +python3 -m pip install pymongo + +# To be clear, install binary for psycopg2 have no negative influence on otel here +# since Otel-Instrumentation running in container that install psycopg2 from source +python3 -m pip install sqlalchemy psycopg2-binary + +# Create mock-collector image +cd contract-tests/images/mock-collector +docker build . -t aws-application-signals-mock-collector-nodejs +if [ $? = 1 ]; then + echo "Docker build for mock collector failed" + exit 1 +fi + +# Find and store aws_opentelemetry_distro whl file +cd ../../../dist +DISTRO=(aws-aws-distro-opentelemetry-node-autoinstrumentation-*.tgz) +if [ "$DISTRO" = "aws-aws-distro-opentelemetry-node-autoinstrumentation-*.tgz" ]; then + echo "Could not find aws_opentelemetry_distro tgz file in dist dir." + exit 1 +fi + +# Create application images +cd .. +for dir in contract-tests/images/applications/* +do + application="${dir##*/}" + docker build . --progress=plain --no-cache -t aws-application-signals-tests-${application}-app -f ${dir}/Dockerfile --build-arg="DISTRO=${DISTRO}" + if [ $? 
= 1 ]; then + echo "Docker build for ${application} application failed" + exit 1 + fi +done + +# Build and install mock-collector +cd contract-tests/images/mock-collector +python3 -m build --outdir ../../../dist +cd ../../../dist +python3 -m pip install mock_collector-1.0.0-py3-none-any.whl --force-reinstall + +# Build and install contract-tests +cd ../contract-tests/tests +python3 -m build --outdir ../../dist +cd ../../dist +# --force-reinstall causes `ERROR: No matching distribution found for mock-collector==1.0.0`, but uninstalling and reinstalling works pretty reliably. +python3 -m pip uninstall contract-tests -y +python3 -m pip install contract_tests-1.0.0-py3-none-any.whl