|
| 1 | +import java.io.{File, IOException} |
| 2 | +import java.net.{HttpURLConnection, URL} |
| 3 | +import scala.concurrent.duration.* |
| 4 | + |
| 5 | +object UrlRetry { |
| 6 | + |
| 7 | + /** Set of transient HTTP status codes that are considered safe to retry: |
| 8 | + * - 408: Request Timeout – server timed out waiting for the request |
| 9 | + * - 425: Too Early – server unwilling to process a possibly replayed request |
| 10 | + * - 429: Too Many Requests – client is being rate-limited |
| 11 | + * - 500: Internal Server Error – generic server-side failure |
| 12 | + * - 502: Bad Gateway – upstream server returned an invalid response |
| 13 | + * - 503: Service Unavailable – server temporarily overloaded or down |
| 14 | + * - 504: Gateway Timeout – upstream server didn't respond in time |
| 15 | + */ |
| 16 | + private val TransientStatusCodes: Set[Int] = |
| 17 | + Set(408, 425, 429, 500, 502, 503, 504) |
| 18 | + |
| 19 | + /** Base type for HTTP-related failures. */ |
| 20 | + sealed abstract class HttpException(message: String) extends RuntimeException(message) |
| 21 | + |
| 22 | + /** Thrown when we got an HTTP response with a status code. */ |
| 23 | + final case class HttpStatusException(status: Int, url: URL, msg: String) |
| 24 | + extends HttpException(s"HTTP $status for $url: $msg") |
| 25 | + |
| 26 | + /** Marker exception used to trigger retries (for transient HTTP statuses). */ |
| 27 | + final case class TransientHttpException(status: Int, url: URL, msg: String) |
| 28 | + extends HttpException(s"Transient HTTP $status for $url: $msg") |
| 29 | + |
| 30 | + /** Retries a block that uses URL.openConnection (or any other network I/O) up to `maxRetries`. |
| 31 | + * |
| 32 | + * - Retries for transient HTTP status codes (defaults in TransientStatusCodes) and IOExceptions. |
| 33 | + * - Retry interval and backoff are configurable. |
| 34 | + * - Backoff is applied multiplicatively. |
| 35 | + * |
| 36 | + * Notes: |
| 37 | + * - The `block` should perform the full request and either: (a) throw on non-2xx (via `openAndCheck` below), or |
| 38 | + * (b) return some value you define after inspecting the response code. |
| 39 | + */ |
| 40 | + def withTransientHttpRetries[T](maxRetries: Int, baseInterval: FiniteDuration, backoffFactor: Double)( |
| 41 | + block: => T |
| 42 | + ): T = { |
| 43 | + require(maxRetries >= 1, "maxRetries must be >= 1") |
| 44 | + require(baseInterval.toMillis >= 0, "baseInterval must be >= 0") |
| 45 | + require(backoffFactor >= 1.0, "backoffFactor must be >= 1.0") |
| 46 | + |
| 47 | + def sleepForAttempt(attempt: Int): Unit = { |
| 48 | + // attempt is 1-based; attempt 1 means "first try" => no sleep before it |
| 49 | + if (attempt > 1) { |
| 50 | + val multiplier = math.pow(backoffFactor, attempt - 2) // before 2nd try => pow(...,0) |
| 51 | + val delayMs = (baseInterval.toMillis.toDouble * multiplier).toLong |
| 52 | + if (delayMs > 0) Thread.sleep(delayMs) |
| 53 | + } |
| 54 | + } |
| 55 | + |
| 56 | + def loop(attempt: Int, lastError: Throwable): T = { |
| 57 | + if (attempt > 1) println(s"[WARN] Attempt $attempt/$maxRetries failed with: ${lastError.getMessage}") |
| 58 | + sleepForAttempt(attempt) |
| 59 | + try { |
| 60 | + block |
| 61 | + } catch { |
| 62 | + case e: TransientHttpException if TransientStatusCodes.contains(e.status) => |
| 63 | + if (attempt >= maxRetries) throw e |
| 64 | + loop(attempt + 1, e) |
| 65 | + case e: IOException => |
| 66 | + if (attempt >= maxRetries) throw e |
| 67 | + loop(attempt + 1, e) |
| 68 | + // If it's an HTTP exception but not transient -> do not retry |
| 69 | + case e: HttpException => |
| 70 | + throw e |
| 71 | + // Non-HTTP failures are not retried |
| 72 | + case other: Throwable => |
| 73 | + throw other |
| 74 | + } |
| 75 | + } |
| 76 | + |
| 77 | + // First attempt has no prior error; provide a dummy Throwable |
| 78 | + loop(attempt = 1, lastError = new RuntimeException("No attempts yet")) |
| 79 | + } |
| 80 | + |
| 81 | + /** Helper that opens a connection, and throws: |
| 82 | + * - TransientHttpException for transient statuses |
| 83 | + * - HttpStatusException for all other non-2xx statuses |
| 84 | + * |
| 85 | + * You can use this inside `withTransientHttpRetries` to ensure retry behavior is based on HTTP status. |
| 86 | + */ |
| 87 | + def openAndCheck(url: URL): HttpURLConnection = { |
| 88 | + val conn = url.openConnection().asInstanceOf[HttpURLConnection] |
| 89 | + |
| 90 | + // Trigger the request; for GET, calling getResponseCode initiates the connection |
| 91 | + val status = conn.getResponseCode |
| 92 | + if (status >= 200 && status <= 299) { |
| 93 | + conn |
| 94 | + } else { |
| 95 | + val msg = Option(conn.getResponseMessage).getOrElse("no message") |
| 96 | + // Ensure connection resources are released on failure |
| 97 | + conn.disconnect() |
| 98 | + if (TransientStatusCodes.contains(status)) throw TransientHttpException(status, url, msg) |
| 99 | + else throw HttpStatusException(status, url, msg) |
| 100 | + } |
| 101 | + } |
| 102 | + |
| 103 | + /** Downloads a URL to a local file with retries for transient HTTP statuses and IOExceptions. */ |
| 104 | + def downloadWithRetries( |
| 105 | + url: URL, |
| 106 | + localFile: File, |
| 107 | + maxRetries: Int = 5, |
| 108 | + baseInterval: FiniteDuration = 500.millis, |
| 109 | + backoffFactor: Double = 2.0 |
| 110 | + ): Unit = { |
| 111 | + withTransientHttpRetries(maxRetries, baseInterval, backoffFactor) { |
| 112 | + val conn = openAndCheck(url) |
| 113 | + try { |
| 114 | + sbt.io.Using.bufferedInputStream(conn.getInputStream) { inputStream => |
| 115 | + sbt.IO.transfer(inputStream, localFile) |
| 116 | + } |
| 117 | + } finally conn.disconnect() |
| 118 | + } |
| 119 | + } |
| 120 | +} |
0 commit comments