# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GCP Cloud Batch helpers.

This module provides a client for interacting with the GCP Batch service. It is
used to run granular tasks that require a high degree of isolation, such as
executing untrusted code from fuzzing jobs. Each task is run in its own VM,
ensuring that any potential security issues are contained.
"""
import threading
from typing import List
from typing import Tuple
import uuid

from google.cloud import batch_v1 as batch

from clusterfuzz._internal.base import retry
from clusterfuzz._internal.google_cloud_utils import credentials
from clusterfuzz._internal.metrics import logs
from clusterfuzz._internal.remote_task import RemoteTaskInterface

_local = threading.local()

DEFAULT_RETRY_COUNT = 0

# Controls how many containers (ClusterFuzz tasks) can run on a single VM.
# THIS SHOULD BE 1 OR THERE WILL BE SECURITY PROBLEMS.
TASK_COUNT_PER_NODE = 1

# See https://cloud.google.com/batch/quotas#job_limits
MAX_CONCURRENT_VMS_PER_JOB = 1000


def _create_batch_client_new():
  """Creates a batch client."""
  creds, _ = credentials.get_default()
  return batch.BatchServiceClient(credentials=creds)


def _batch_client():
  """Gets the batch client, creating it if it does not exist."""
  if hasattr(_local, 'client'):
    return _local.client

  _local.client = _create_batch_client_new()
  return _local.client


def get_job_name():
  """Returns a unique, randomly generated ID for a new batch job."""
  return 'j-' + str(uuid.uuid4()).lower()


def _get_task_spec(batch_workload_spec):
  """Gets the task spec based on the batch workload spec."""
  runnable = batch.Runnable()
  runnable.container = batch.Runnable.Container()
  runnable.container.image_uri = batch_workload_spec.docker_image
  clusterfuzz_release = batch_workload_spec.clusterfuzz_release
  # The container runs fully privileged; isolation comes from the per-task VM
  # (see TASK_COUNT_PER_NODE), not from the container boundary.
  runnable.container.options = (
      '--memory-swappiness=40 --shm-size=1.9g --rm --net=host '
      '-e HOST_UID=1337 -P --privileged --cap-add=all '
      f'-e CLUSTERFUZZ_RELEASE={clusterfuzz_release} '
      '--name=clusterfuzz -e UNTRUSTED_WORKER=False -e UWORKER=True '
      '-e USE_GCLOUD_STORAGE_RSYNC=1 '
      '-e UWORKER_INPUT_DOWNLOAD_URL')
  runnable.container.volumes = ['/var/scratch0:/mnt/scratch0']
  task_spec = batch.TaskSpec()
  task_spec.runnables = [runnable]
  if batch_workload_spec.retry:
    # Tasks in general have 6 hours to run (except pruning, which has 24).
    # Our signed URLs last 24 hours. Therefore, the maximum number of retries
    # is 4. This is a temporary solution anyway.
    task_spec.max_retry_count = 4
  else:
    task_spec.max_retry_count = DEFAULT_RETRY_COUNT
  task_spec.max_run_duration = batch_workload_spec.max_run_duration
  return task_spec


def _set_preemptible(instance_policy, batch_workload_spec) -> None:
  """Sets the provisioning model (preemptible or standard) on the policy."""
  if batch_workload_spec.preemptible:
    instance_policy.provisioning_model = (
        batch.AllocationPolicy.ProvisioningModel.PREEMPTIBLE)
  else:
    instance_policy.provisioning_model = (
        batch.AllocationPolicy.ProvisioningModel.STANDARD)


def _get_allocation_policy(spec):
  """Returns the allocation policy for a BatchWorkloadSpec."""
  disk = batch.AllocationPolicy.Disk()
  disk.image = 'batch-cos'
  disk.size_gb = spec.disk_size_gb
  # proto-plus renames the proto field `type` to `type_` in Python.
  disk.type_ = spec.disk_type
  instance_policy = batch.AllocationPolicy.InstancePolicy()
  instance_policy.boot_disk = disk
  instance_policy.machine_type = spec.machine_type
  _set_preemptible(instance_policy, spec)
  instances = batch.AllocationPolicy.InstancePolicyOrTemplate()
  instances.policy = instance_policy

  # Don't use external IP addresses, which use quota, cost money, and are
  # unnecessary.
  network_interface = batch.AllocationPolicy.NetworkInterface()
  network_interface.no_external_ip_address = True
  network_interface.network = spec.network
  network_interface.subnetwork = spec.subnetwork
  network_interfaces = [network_interface]
  network_policy = batch.AllocationPolicy.NetworkPolicy()
  network_policy.network_interfaces = network_interfaces

  allocation_policy = batch.AllocationPolicy()
  allocation_policy.instances = [instances]
  allocation_policy.network = network_policy
  service_account = batch.ServiceAccount(email=spec.service_account_email)  # pylint: disable=no-member
  allocation_policy.service_account = service_account
  return allocation_policy


@retry.wrap(
    retries=3,
    delay=2,
    function='google_cloud_utils.batch._send_create_job_request')
def _send_create_job_request(create_request):
  """Sends the create job request to the GCP Batch service, with retries."""
  return _batch_client().create_job(create_request)


def count_queued_or_scheduled_tasks(project: str,
                                    region: str) -> Tuple[int, int]:
  """Counts the number of queued and scheduled tasks."""
  region = f'projects/{project}/locations/{region}'
  jobs_filter = 'Status.State="SCHEDULED" OR Status.State="QUEUED"'
  req = batch.types.ListJobsRequest(parent=region, filter=jobs_filter)
  queued = 0
  scheduled = 0
  for job in _batch_client().list_jobs(request=req):
    if job.status.state == batch.JobStatus.State.SCHEDULED:
      scheduled += job.task_groups[0].task_count
    elif job.status.state == batch.JobStatus.State.QUEUED:
      queued += job.task_groups[0].task_count
  return (queued, scheduled)


class GcpBatchClient(RemoteTaskInterface):
  """A client for creating and managing jobs on the GCP Batch service.

  This client is responsible for translating ClusterFuzz task specifications
  into GCP Batch jobs. It handles the configuration of the job, including
  the machine type, disk size, and network settings, as well as the task
  specification, which defines the container image and command to run.
  """

  def create_job(self, spec, input_urls: List[str]):
    """Creates and starts a batch job from |spec| that executes all tasks.

    This method creates a new GCP Batch job with a single task group. The
    task group is configured to run a containerized task for each of the
    input URLs. The tasks run in parallel, each on its own VM, as enforced
    by the TASK_COUNT_PER_NODE setting.
    """
    task_group = batch.TaskGroup()
    task_group.task_count = len(input_urls)
    assert task_group.task_count < MAX_CONCURRENT_VMS_PER_JOB
    task_environments = [
        batch.Environment(variables={'UWORKER_INPUT_DOWNLOAD_URL': input_url})
        for input_url in input_urls
    ]
    task_group.task_environments = task_environments
    task_group.task_spec = _get_task_spec(spec)
    task_group.task_count_per_node = TASK_COUNT_PER_NODE
    assert task_group.task_count_per_node == 1, 'This is a security issue'

    job = batch.Job()
    job.task_groups = [task_group]
    job.allocation_policy = _get_allocation_policy(spec)
    job.logs_policy = batch.LogsPolicy()
    job.logs_policy.destination = batch.LogsPolicy.Destination.CLOUD_LOGGING
    job.priority = spec.priority

    create_request = batch.CreateJobRequest()
    create_request.job = job
    job_name = get_job_name()
    create_request.job_id = job_name
    # The job's parent is the region in which the job will run.
    project_id = spec.project
    create_request.parent = f'projects/{project_id}/locations/{spec.gce_region}'
    job_result = _send_create_job_request(create_request)
    logs.info(f'Created batch job id={job_name}.', spec=spec)
    return job_result
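

# Illustrative usage: a minimal sketch, not part of the production flow. In
# ClusterFuzz the real spec object and the uworker input URLs come from the
# task scheduling machinery; the SimpleNamespace below is a hypothetical
# stand-in built only from the fields this module actually reads, and every
# project/region/network/URL value is a placeholder. Running this for real
# requires GCP credentials and assumes GcpBatchClient is directly
# instantiable (i.e. RemoteTaskInterface demands no other abstract methods).
if __name__ == '__main__':
  import types

  example_spec = types.SimpleNamespace(
      docker_image='gcr.io/example-project/clusterfuzz:latest',  # Placeholder.
      clusterfuzz_release='prod',
      retry=False,
      max_run_duration='21600s',  # 6 hours; proto-plus accepts Duration strings.
      preemptible=True,
      machine_type='n2-standard-4',
      disk_size_gb=100,
      disk_type='pd-standard',
      network='projects/example-project/global/networks/default',
      subnetwork=('projects/example-project/regions/us-central1/'
                  'subnetworks/default'),
      service_account_email=(
          'batch-worker@example-project.iam.gserviceaccount.com'),
      priority=0,
      project='example-project',
      gce_region='us-central1')

  client = GcpBatchClient()
  # One VM per input URL; each URL points at a signed uworker input payload.
  created_job = client.create_job(example_spec,
                                  ['https://storage.example/input-0'])
  print(f'Job state: {created_job.status.state}')

  queued, scheduled = count_queued_or_scheduled_tasks('example-project',
                                                      'us-central1')
  print(f'Queued: {queued}, scheduled: {scheduled}')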