|
1 | 1 | import type { DB } from '@matrixai/db'; |
| 2 | +import type { ResourceAcquire } from '@matrixai/resources'; |
2 | 3 | import type { |
3 | 4 | ClientRPCRequestParams, |
4 | 5 | ClientRPCResponseResult, |
5 | | - SecretIdentifierMessage, |
| 6 | + SecretsRemoveHeaderMessage, |
| 7 | + SecretIdentifierMessageTagged, |
6 | 8 | SuccessOrErrorMessage, |
7 | 9 | } from '../types'; |
8 | 10 | import type VaultManager from '../../vaults/VaultManager'; |
| 11 | +import type { FileSystemWritable } from '../../vaults/types'; |
| 12 | +import { withG } from '@matrixai/resources'; |
9 | 13 | import { DuplexHandler } from '@matrixai/rpc'; |
10 | 14 | import * as vaultsUtils from '../../vaults/utils'; |
11 | 15 | import * as vaultsErrors from '../../vaults/errors'; |
| 16 | +import * as clientErrors from '../errors'; |
12 | 17 |
|
13 | 18 | class VaultsSecretsRemove extends DuplexHandler< |
14 | 19 | { |
15 | 20 | db: DB; |
16 | 21 | vaultManager: VaultManager; |
17 | 22 | }, |
18 | | - ClientRPCRequestParams<SecretIdentifierMessage>, |
| 23 | + ClientRPCRequestParams< |
| 24 | + SecretsRemoveHeaderMessage | SecretIdentifierMessageTagged |
| 25 | + >, |
19 | 26 | ClientRPCResponseResult<SuccessOrErrorMessage> |
20 | 27 | > { |
21 | 28 | public handle = async function* ( |
22 | | - input: AsyncIterable<ClientRPCRequestParams<SecretIdentifierMessage>>, |
| 29 | + input: AsyncIterableIterator< |
| 30 | + ClientRPCRequestParams< |
| 31 | + SecretsRemoveHeaderMessage | SecretIdentifierMessageTagged |
| 32 | + > |
| 33 | + >, |
23 | 34 | ): AsyncGenerator<ClientRPCResponseResult<SuccessOrErrorMessage>> { |
24 | 35 | const { db, vaultManager }: { db: DB; vaultManager: VaultManager } = |
25 | 36 | this.container; |
26 | | - // Create a record of secrets to be removed, grouped by vault names |
27 | | - const vaultGroups: Record<string, Array<string>> = {}; |
28 | | - const secretNames: Array<[string, string]> = []; |
29 | | - let metadata: any = undefined; |
30 | | - for await (const secretRemoveMessage of input) { |
31 | | - if (metadata == null) metadata = secretRemoveMessage.metadata ?? {}; |
32 | | - secretNames.push([ |
33 | | - secretRemoveMessage.nameOrId, |
34 | | - secretRemoveMessage.secretName, |
35 | | - ]); |
| 37 | + // Extract the header message from the iterator |
| 38 | + const headerMessagePair = await input.next(); |
| 39 | + const headerMessage: |
| 40 | + | SecretsRemoveHeaderMessage |
| 41 | + | SecretIdentifierMessageTagged = headerMessagePair.value; |
| 42 | + // Testing if the header is of the expected format |
| 43 | + if ( |
| 44 | + headerMessagePair.done || |
| 45 | + headerMessage.type !== 'VaultNamesHeaderMessage' |
| 46 | + ) { |
| 47 | + throw new clientErrors.ErrorClientInvalidHeader(); |
36 | 48 | } |
37 | | - secretNames.forEach(([vaultName, secretName]) => { |
38 | | - if (vaultGroups[vaultName] == null) { |
39 | | - vaultGroups[vaultName] = []; |
| 49 | + // Create an array of write acquires |
| 50 | + const vaultAcquires = await db.withTransactionF(async (tran) => { |
| 51 | + const vaultAcquires: Array<ResourceAcquire<FileSystemWritable>> = []; |
| 52 | + for (const vaultName of headerMessage.vaultNames) { |
| 53 | + const vaultIdFromName = await vaultManager.getVaultId(vaultName, tran); |
| 54 | + const vaultId = vaultIdFromName ?? vaultsUtils.decodeVaultId(vaultName); |
| 55 | + if (vaultId == null) { |
| 56 | + throw new vaultsErrors.ErrorVaultsVaultUndefined( |
| 57 | + `Vault ${vaultName} does not exist`, |
| 58 | + ); |
| 59 | + } |
| 60 | + const acquire = await vaultManager.withVaults( |
| 61 | + [vaultId], |
| 62 | + async (vault) => vault.acquireWrite(), |
| 63 | + ); |
| 64 | + vaultAcquires.push(acquire); |
40 | 65 | } |
41 | | - vaultGroups[vaultName].push(secretName); |
| 66 | + return vaultAcquires; |
42 | 67 | }); |
43 | | - // Now, all the paths will be removed for a vault within a single commit |
44 | | - yield* db.withTransactionG( |
45 | | - async function* (tran): AsyncGenerator<SuccessOrErrorMessage> { |
46 | | - for (const [vaultName, secretNames] of Object.entries(vaultGroups)) { |
47 | | - const vaultIdFromName = await vaultManager.getVaultId( |
48 | | - vaultName, |
49 | | - tran, |
50 | | - ); |
51 | | - const vaultId = |
52 | | - vaultIdFromName ?? vaultsUtils.decodeVaultId(vaultName); |
53 | | - if (vaultId == null) { |
54 | | - throw new vaultsErrors.ErrorVaultsVaultUndefined(); |
| 68 | + // Acquire all locks in parallel and perform all operations at once |
| 69 | + yield* withG( |
| 70 | + vaultAcquires, |
| 71 | + async function* (efses): AsyncGenerator<SuccessOrErrorMessage> { |
| 72 | + // Creating the vault name to efs map for easy access |
| 73 | + const vaultMap = new Map<string, FileSystemWritable>(); |
| 74 | + for (let i = 0; i < efses.length; i++) { |
| 75 | + vaultMap.set(headerMessage!.vaultNames[i], efses[i]); |
| 76 | + } |
| 77 | + let loopRan = false; |
| 78 | + for await (const message of input) { |
| 79 | + loopRan = true; |
| 80 | + // Header messages should not be seen anymore |
| 81 | + if (message.type === 'VaultNamesHeaderMessage') { |
| 82 | + throw new clientErrors.ErrorClientProtocolError( |
| 83 | + 'The header message cannot be sent multiple times', |
| 84 | + ); |
55 | 85 | } |
56 | | - yield* vaultManager.withVaultsG( |
57 | | - [vaultId], |
58 | | - async function* (vault): AsyncGenerator<SuccessOrErrorMessage> { |
59 | | - yield* vault.writeG( |
60 | | - async function* (efs): AsyncGenerator<SuccessOrErrorMessage> { |
61 | | - for (const secretName of secretNames) { |
62 | | - try { |
63 | | - const stat = await efs.stat(secretName); |
64 | | - if (stat.isDirectory()) { |
65 | | - await efs.rmdir(secretName, { |
66 | | - recursive: metadata?.options?.recursive, |
67 | | - }); |
68 | | - } else { |
69 | | - await efs.unlink(secretName); |
70 | | - } |
71 | | - yield { |
72 | | - type: 'success', |
73 | | - success: true, |
74 | | - }; |
75 | | - } catch (e) { |
76 | | - if ( |
77 | | - e.code === 'ENOENT' || |
78 | | - e.code === 'ENOTEMPTY' || |
79 | | - e.code === 'EINVAL' |
80 | | - ) { |
81 | | - // INVAL can be triggered if removing the root of the |
82 | | - // vault is attempted. |
83 | | - yield { |
84 | | - type: 'error', |
85 | | - code: e.code, |
86 | | - reason: secretName, |
87 | | - }; |
88 | | - } else { |
89 | | - throw e; |
90 | | - } |
91 | | - } |
92 | | - } |
93 | | - }, |
94 | | - ); |
95 | | - }, |
96 | | - tran, |
| 86 | + const efs = vaultMap.get(message.nameOrId); |
| 87 | + if (efs == null) { |
| 88 | + throw new vaultsErrors.ErrorVaultsVaultUndefined( |
| 89 | + `Vault ${message.nameOrId} was not present in the header message`, |
| 90 | + ); |
| 91 | + } |
| 92 | + try { |
| 93 | + const stat = await efs.stat(message.secretName); |
| 94 | + if (stat.isDirectory()) { |
| 95 | + await efs.rmdir(message.secretName, { |
| 96 | + recursive: headerMessage.recursive, |
| 97 | + }); |
| 98 | + } else { |
| 99 | + await efs.unlink(message.secretName); |
| 100 | + } |
| 101 | + yield { |
| 102 | + type: 'success', |
| 103 | + success: true, |
| 104 | + }; |
| 105 | + } catch (e) { |
| 106 | + if ( |
| 107 | + e.code === 'ENOENT' || |
| 108 | + e.code === 'ENOTEMPTY' || |
| 109 | + e.code === 'EINVAL' |
| 110 | + ) { |
| 111 | + // EINVAL can be triggered if removing the root of the |
| 112 | + // vault is attempted. |
| 113 | + yield { |
| 114 | + type: 'error', |
| 115 | + code: e.code, |
| 116 | + reason: message.secretName, |
| 117 | + }; |
| 118 | + } else { |
| 119 | + throw e; |
| 120 | + } |
| 121 | + } |
| 122 | + } |
| 123 | + // Content messages must follow header messages |
| 124 | + if (!loopRan) { |
| 125 | + throw new clientErrors.ErrorClientProtocolError( |
| 126 | + 'No content messages followed header message', |
97 | 127 | ); |
98 | 128 | } |
99 | 129 | }, |
|
0 commit comments