Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
import { describe, it, beforeEach, afterEach, expect, spyOn, mock } from 'bun:test';
import { FederationRequestService } from './federation-request.service';
import { FederationConfigService } from './federation-config.service';
import * as nacl from 'tweetnacl';
import * as discovery from '@hs/homeserver/src/helpers/server-discovery/discovery';
import * as authentication from '@hs/homeserver/src/authentication';
import * as signJson from '@hs/homeserver/src/signJson';
import * as url from '@hs/homeserver/src/helpers/url';

describe('FederationRequestService', () => {
let service: FederationRequestService;
let configService: FederationConfigService;
let originalFetch: typeof globalThis.fetch;

const mockServerName = 'example.com';
const mockSigningKey = 'aGVsbG93b3JsZA==';
const mockSigningKeyId = 'ed25519:1';

const mockKeyPair = {
publicKey: new Uint8Array([1, 2, 3]),
secretKey: new Uint8Array([4, 5, 6]),
};

const mockDiscoveryResult = {
address: 'target.example.com',
headers: {
'Host': 'target.example.com',
'X-Custom-Header': 'Test'
},
};

const mockSignature = new Uint8Array([7, 8, 9]);

const mockSignedJson = {
content: 'test',
signatures: {
'example.com': {
'ed25519:1': 'abcdef',
},
},
};

const mockAuthHeaders = 'X-Matrix origin="example.com",destination="target.example.com",key="ed25519:1",sig="xyz123"';

beforeEach(() => {
originalFetch = globalThis.fetch;

spyOn(nacl.sign.keyPair, 'fromSecretKey').mockReturnValue(mockKeyPair);
spyOn(nacl.sign, 'detached').mockReturnValue(mockSignature);

spyOn(discovery, 'resolveHostAddressByServerName').mockResolvedValue(mockDiscoveryResult);
spyOn(url, 'extractURIfromURL').mockReturnValue('/test/path?query=value');
spyOn(authentication, 'authorizationHeaders').mockResolvedValue(mockAuthHeaders);
spyOn(signJson, 'signJson').mockResolvedValue(mockSignedJson);
spyOn(authentication, 'computeAndMergeHash').mockImplementation((obj: any) => obj);

globalThis.fetch = Object.assign(
async (_url: string, _options?: RequestInit) => {
return {
ok: true,
status: 200,
json: async () => ({ result: 'success' }),
text: async () => '{"result":"success"}',
} as Response;
},
{ preconnect: () => { } }
) as typeof fetch;

configService = {
serverName: mockServerName,
signingKey: mockSigningKey,
signingKeyId: mockSigningKeyId,
} as FederationConfigService;

service = new FederationRequestService(configService);
});

afterEach(() => {
globalThis.fetch = originalFetch;
mock.restore();
});

describe('makeSignedRequest', () => {
it('should make a successful signed request without body', async () => {
const fetchSpy = spyOn(globalThis, 'fetch');

const result = await service.makeSignedRequest({
method: 'GET',
domain: 'target.example.com',
uri: '/test/path',
});

expect(configService.serverName).toBe(mockServerName);
expect(configService.signingKey).toBe(mockSigningKey);
expect(configService.signingKeyId).toBe(mockSigningKeyId);

expect(nacl.sign.keyPair.fromSecretKey).toHaveBeenCalled();

expect(discovery.resolveHostAddressByServerName).toHaveBeenCalledWith(
'target.example.com',
mockServerName
);

expect(fetchSpy).toHaveBeenCalledWith(
'https://target.example.com/test/path',
expect.objectContaining({
method: 'GET',
headers: expect.objectContaining({
Authorization: mockAuthHeaders,
'X-Custom-Header': 'Test',
}),
})
);

expect(result).toEqual({ result: 'success' });
});

it('should make a successful signed request with body', async () => {
const fetchSpy = spyOn(globalThis, 'fetch');

const mockBody = { key: 'value' };

const result = await service.makeSignedRequest({
method: 'POST',
domain: 'target.example.com',
uri: '/test/path',
body: mockBody,
});

expect(signJson.signJson).toHaveBeenCalledWith(
expect.objectContaining({ key: 'value', signatures: {} }),
expect.any(Object),
mockServerName
);

expect(authentication.authorizationHeaders).toHaveBeenCalledWith(
mockServerName,
expect.any(Object),
'target.example.com',
'POST',
'/test/path?query=value',
mockSignedJson
);

expect(fetchSpy).toHaveBeenCalledWith(
'https://target.example.com/test/path',
expect.objectContaining({
method: 'POST',
body: JSON.stringify(mockSignedJson),
})
);

expect(result).toEqual({ result: 'success' });
});

it('should make a signed request with query parameters', async () => {
const fetchSpy = spyOn(globalThis, 'fetch');

const result = await service.makeSignedRequest({
method: 'GET',
domain: 'target.example.com',
uri: '/test/path',
queryString: 'param1=value1&param2=value2',
});

expect(fetchSpy).toHaveBeenCalledWith(
'https://target.example.com/test/path?param1=value1&param2=value2',
expect.any(Object)
);

expect(result).toEqual({ result: 'success' });
});

it('should handle fetch errors properly', async () => {
globalThis.fetch = Object.assign(
async () => {
return {
ok: false,
status: 404,
text: async () => 'Not Found',
} as Response;
},
{ preconnect: () => { } }
) as typeof fetch;

try {
await service.makeSignedRequest({
method: 'GET',
domain: 'target.example.com',
uri: '/test/path',
});
} catch (error: unknown) {
if (error instanceof Error) {
expect(error.message).toContain('Federation request failed: 404 Not Found');
} else {
throw error;
}
}
});

it('should handle JSON error responses properly', async () => {
globalThis.fetch = Object.assign(
async () => {
return {
ok: false,
status: 400,
text: async () => '{"error":"Bad Request","code":"M_INVALID_PARAM"}',
} as Response;
},
{ preconnect: () => { } }
) as typeof fetch;

try {
await service.makeSignedRequest({
method: 'GET',
domain: 'target.example.com',
uri: '/test/path',
});
} catch (error: unknown) {
if (error instanceof Error) {
expect(error.message).toContain('Federation request failed: 400 {"error":"Bad Request","code":"M_INVALID_PARAM"}');
} else {
throw error;
}
}
});

it('should handle network errors properly', async () => {
globalThis.fetch = Object.assign(
async () => {
throw new Error('Network Error');
},
{ preconnect: () => { } }
) as typeof fetch;

try {
await service.makeSignedRequest({
method: 'GET',
domain: 'target.example.com',
uri: '/test/path',
});
} catch (error: unknown) {
if (error instanceof Error) {
expect(error.message).toBe('Network Error');
} else {
throw error;
}
}
});
});

describe('convenience methods', () => {
it('should call makeSignedRequest with correct parameters for GET', async () => {
const makeSignedRequestSpy = spyOn(service, 'makeSignedRequest').mockResolvedValue({ result: 'success' });

await service.get('target.example.com', '/api/resource', { filter: 'active' });

expect(makeSignedRequestSpy).toHaveBeenCalledWith({
method: 'GET',
domain: 'target.example.com',
uri: '/api/resource',
queryString: 'filter=active',
});
});

it('should call makeSignedRequest with correct parameters for POST', async () => {
const makeSignedRequestSpy = spyOn(service, 'makeSignedRequest').mockResolvedValue({ result: 'success' });

const body = { data: 'example' };
await service.post('target.example.com', '/api/resource', body, { version: '1' });

expect(makeSignedRequestSpy).toHaveBeenCalledWith({
method: 'POST',
domain: 'target.example.com',
uri: '/api/resource',
body,
queryString: 'version=1',
});
});

it('should call makeSignedRequest with correct parameters for PUT', async () => {
const makeSignedRequestSpy = spyOn(service, 'makeSignedRequest').mockResolvedValue({ result: 'success' });

const body = { data: 'updated' };
await service.put('target.example.com', '/api/resource/123', body);

expect(makeSignedRequestSpy).toHaveBeenCalledWith({
method: 'PUT',
domain: 'target.example.com',
uri: '/api/resource/123',
body,
queryString: '',
});
});
});
});
22 changes: 11 additions & 11 deletions packages/federation-sdk/src/services/federation-request.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ import { Injectable, Logger } from '@nestjs/common';
import * as nacl from 'tweetnacl';
import { authorizationHeaders, computeAndMergeHash } from '../../../homeserver/src/authentication';
import { extractURIfromURL } from '../../../homeserver/src/helpers/url';
import { signJson } from '../../../homeserver/src/signJson';
import { EncryptionValidAlgorithm, signJson } from '../../../homeserver/src/signJson';
import { FederationConfigService } from './federation-config.service';

interface SignedRequest {
method: string;
domain: string;
uri: string;
body?: any;
body?: Record<string, unknown>;
queryString?: string;
}

Expand All @@ -37,8 +37,8 @@ export class FederationRequestService {
const privateKeyBytes = Buffer.from(signingKeyBase64, 'base64');
const keyPair = nacl.sign.keyPair.fromSecretKey(privateKeyBytes);

const signingKey = {
algorithm: 'ed25519',
const signingKey: SigningKey = {
algorithm: EncryptionValidAlgorithm.ed25519,
version: signingKeyId.split(':')[1] || '1',
privateKey: keyPair.secretKey,
publicKey: keyPair.publicKey,
Expand All @@ -57,27 +57,27 @@ export class FederationRequestService {

this.logger.debug(`Making ${method} request to ${url.toString()}`);

let signedBody: unknown;
let signedBody: Record<string, unknown> | undefined;
if (body) {
signedBody = await signJson(
computeAndMergeHash({ ...body, signatures: {} }),
signingKey as any,
signingKey,
serverName
);
}

const auth = await authorizationHeaders(
serverName,
signingKey as unknown as SigningKey,
signingKey,
domain,
method,
extractURIfromURL(url),
signedBody as any,
signedBody,
);

const response = await fetch(url.toString(), {
method,
...(signedBody && { body: JSON.stringify(signedBody) }) as any,
...(signedBody && { body: JSON.stringify(signedBody) }),
headers: {
Authorization: auth,
...discoveryHeaders,
Expand All @@ -93,14 +93,14 @@ export class FederationRequestService {
throw new Error(`Federation request failed: ${response.status} ${errorDetail}`);
}

return response.json() as Promise<T>;
return response.json();
} catch (error: any) {
this.logger.error(`Federation request failed: ${error.message}`, error.stack);
throw error;
}
}

async request<T>(method: HttpMethod, targetServer: string, endpoint: string, body?: any, queryParams?: Record<string, string>): Promise<T> {
async request<T>(method: HttpMethod, targetServer: string, endpoint: string, body?: Record<string, unknown>, queryParams?: Record<string, string>): Promise<T> {
let queryString = '';

if (queryParams) {
Expand Down
Loading