@@ -935,7 +935,8 @@ CreateFileHandler(MemoryBuffer &FirstInput,
935935 " '" + FilesType + " ': invalid file type specified" );
936936}
937937
938- OffloadBundlerConfig::OffloadBundlerConfig () {
938+ OffloadBundlerConfig::OffloadBundlerConfig ()
939+ : CompressedBundleVersion(CompressedOffloadBundle::DefaultVersion) {
939940 if (llvm::compression::zstd::isAvailable ()) {
940941 CompressionFormat = llvm::compression::Format::Zstd;
941942 // Compression level 3 is usually sufficient for zstd since long distance
@@ -951,16 +952,13 @@ OffloadBundlerConfig::OffloadBundlerConfig() {
951952 llvm::sys::Process::GetEnv (" OFFLOAD_BUNDLER_IGNORE_ENV_VAR" );
952953 if (IgnoreEnvVarOpt.has_value () && IgnoreEnvVarOpt.value () == " 1" )
953954 return ;
954-
955955 auto VerboseEnvVarOpt = llvm::sys::Process::GetEnv (" OFFLOAD_BUNDLER_VERBOSE" );
956956 if (VerboseEnvVarOpt.has_value ())
957957 Verbose = VerboseEnvVarOpt.value () == " 1" ;
958-
959958 auto CompressEnvVarOpt =
960959 llvm::sys::Process::GetEnv (" OFFLOAD_BUNDLER_COMPRESS" );
961960 if (CompressEnvVarOpt.has_value ())
962961 Compress = CompressEnvVarOpt.value () == " 1" ;
963-
964962 auto CompressionLevelEnvVarOpt =
965963 llvm::sys::Process::GetEnv (" OFFLOAD_BUNDLER_COMPRESSION_LEVEL" );
966964 if (CompressionLevelEnvVarOpt.has_value ()) {
@@ -973,6 +971,26 @@ OffloadBundlerConfig::OffloadBundlerConfig() {
973971 << " Warning: Invalid value for OFFLOAD_BUNDLER_COMPRESSION_LEVEL: "
974972 << CompressionLevelStr.str () << " . Ignoring it.\n " ;
975973 }
974+ auto CompressedBundleFormatVersionOpt =
975+ llvm::sys::Process::GetEnv (" COMPRESSED_BUNDLE_FORMAT_VERSION" );
976+ if (CompressedBundleFormatVersionOpt.has_value ()) {
977+ llvm::StringRef VersionStr = CompressedBundleFormatVersionOpt.value ();
978+ uint16_t Version;
979+ if (!VersionStr.getAsInteger (10 , Version)) {
980+ if (Version >= 2 && Version <= 3 )
981+ CompressedBundleVersion = Version;
982+ else
983+ llvm::errs ()
984+ << " Warning: Invalid value for COMPRESSED_BUNDLE_FORMAT_VERSION: "
985+ << VersionStr.str ()
986+ << " . Valid values are 2 or 3. Using default version "
987+ << CompressedBundleVersion << " .\n " ;
988+ } else
989+ llvm::errs ()
990+ << " Warning: Invalid value for COMPRESSED_BUNDLE_FORMAT_VERSION: "
991+ << VersionStr.str () << " . Using default version "
992+ << CompressedBundleVersion << " .\n " ;
993+ }
976994}
977995
978996// Utility function to format numbers with commas
@@ -989,12 +1007,11 @@ static std::string formatWithCommas(unsigned long long Value) {
9891007llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
9901008CompressedOffloadBundle::compress (llvm::compression::Params P,
9911009 const llvm::MemoryBuffer &Input,
992- bool Verbose) {
1010+ uint16_t Version, bool Verbose) {
9931011 if (!llvm::compression::zstd::isAvailable () &&
9941012 !llvm::compression::zlib::isAvailable ())
9951013 return createStringError (llvm::inconvertibleErrorCode (),
9961014 " Compression not supported" );
997-
9981015 llvm::Timer HashTimer (" Hash Calculation Timer" , " Hash calculation time" ,
9991016 *ClangOffloadBundlerTimerGroup);
10001017 if (Verbose)
@@ -1011,7 +1028,6 @@ CompressedOffloadBundle::compress(llvm::compression::Params P,
10111028 auto BufferUint8 = llvm::ArrayRef<uint8_t >(
10121029 reinterpret_cast <const uint8_t *>(Input.getBuffer ().data ()),
10131030 Input.getBuffer ().size ());
1014-
10151031 llvm::Timer CompressTimer (" Compression Timer" , " Compression time" ,
10161032 *ClangOffloadBundlerTimerGroup);
10171033 if (Verbose)
@@ -1021,22 +1037,54 @@ CompressedOffloadBundle::compress(llvm::compression::Params P,
10211037 CompressTimer.stopTimer ();
10221038
10231039 uint16_t CompressionMethod = static_cast <uint16_t >(P.format );
1024- uint32_t UncompressedSize = Input.getBuffer ().size ();
1025- uint32_t TotalFileSize = MagicNumber.size () + sizeof (TotalFileSize) +
1026- sizeof (Version) + sizeof (CompressionMethod) +
1027- sizeof (UncompressedSize) + sizeof (TruncatedHash) +
1028- CompressedBuffer.size ();
1040+
1041+ // Store sizes in 64-bit variables first
1042+ uint64_t UncompressedSize64 = Input.getBuffer ().size ();
1043+ uint64_t TotalFileSize64;
1044+
1045+ // Calculate total file size based on version
1046+ if (Version == 2 ) {
1047+ // For V2, ensure the sizes don't exceed 32-bit limit
1048+ if (UncompressedSize64 > std::numeric_limits<uint32_t >::max ())
1049+ return createStringError (llvm::inconvertibleErrorCode (),
1050+ " Uncompressed size exceeds version 2 limit" );
1051+ if ((MagicNumber.size () + sizeof (uint32_t ) + sizeof (Version) +
1052+ sizeof (CompressionMethod) + sizeof (uint32_t ) + sizeof (TruncatedHash) +
1053+ CompressedBuffer.size ()) > std::numeric_limits<uint32_t >::max ())
1054+ return createStringError (llvm::inconvertibleErrorCode (),
1055+ " Total file size exceeds version 2 limit" );
1056+
1057+ TotalFileSize64 = MagicNumber.size () + sizeof (uint32_t ) + sizeof (Version) +
1058+ sizeof (CompressionMethod) + sizeof (uint32_t ) +
1059+ sizeof (TruncatedHash) + CompressedBuffer.size ();
1060+ } else { // Version 3
1061+ TotalFileSize64 = MagicNumber.size () + sizeof (uint64_t ) + sizeof (Version) +
1062+ sizeof (CompressionMethod) + sizeof (uint64_t ) +
1063+ sizeof (TruncatedHash) + CompressedBuffer.size ();
1064+ }
10291065
10301066 SmallVector<char , 0 > FinalBuffer;
10311067 llvm::raw_svector_ostream OS (FinalBuffer);
10321068 OS << MagicNumber;
10331069 OS.write (reinterpret_cast <const char *>(&Version), sizeof (Version));
10341070 OS.write (reinterpret_cast <const char *>(&CompressionMethod),
10351071 sizeof (CompressionMethod));
1036- OS.write (reinterpret_cast <const char *>(&TotalFileSize),
1037- sizeof (TotalFileSize));
1038- OS.write (reinterpret_cast <const char *>(&UncompressedSize),
1039- sizeof (UncompressedSize));
1072+
1073+ // Write size fields according to version
1074+ if (Version == 2 ) {
1075+ uint32_t TotalFileSize32 = static_cast <uint32_t >(TotalFileSize64);
1076+ uint32_t UncompressedSize32 = static_cast <uint32_t >(UncompressedSize64);
1077+ OS.write (reinterpret_cast <const char *>(&TotalFileSize32),
1078+ sizeof (TotalFileSize32));
1079+ OS.write (reinterpret_cast <const char *>(&UncompressedSize32),
1080+ sizeof (UncompressedSize32));
1081+ } else { // Version 3
1082+ OS.write (reinterpret_cast <const char *>(&TotalFileSize64),
1083+ sizeof (TotalFileSize64));
1084+ OS.write (reinterpret_cast <const char *>(&UncompressedSize64),
1085+ sizeof (UncompressedSize64));
1086+ }
1087+
10401088 OS.write (reinterpret_cast <const char *>(&TruncatedHash),
10411089 sizeof (TruncatedHash));
10421090 OS.write (reinterpret_cast <const char *>(CompressedBuffer.data ()),
@@ -1046,18 +1094,17 @@ CompressedOffloadBundle::compress(llvm::compression::Params P,
10461094 auto MethodUsed =
10471095 P.format == llvm::compression::Format::Zstd ? " zstd" : " zlib" ;
10481096 double CompressionRate =
1049- static_cast <double >(UncompressedSize ) / CompressedBuffer.size ();
1097+ static_cast <double >(UncompressedSize64 ) / CompressedBuffer.size ();
10501098 double CompressionTimeSeconds = CompressTimer.getTotalTime ().getWallTime ();
10511099 double CompressionSpeedMBs =
1052- (UncompressedSize / (1024.0 * 1024.0 )) / CompressionTimeSeconds;
1053-
1100+ (UncompressedSize64 / (1024.0 * 1024.0 )) / CompressionTimeSeconds;
10541101 llvm::errs () << " Compressed bundle format version: " << Version << " \n "
10551102 << " Total file size (including headers): "
1056- << formatWithCommas (TotalFileSize ) << " bytes\n "
1103+ << formatWithCommas (TotalFileSize64 ) << " bytes\n "
10571104 << " Compression method used: " << MethodUsed << " \n "
10581105 << " Compression level: " << P.level << " \n "
10591106 << " Binary size before compression: "
1060- << formatWithCommas (UncompressedSize ) << " bytes\n "
1107+ << formatWithCommas (UncompressedSize64 ) << " bytes\n "
10611108 << " Binary size after compression: "
10621109 << formatWithCommas (CompressedBuffer.size ()) << " bytes\n "
10631110 << " Compression rate: "
@@ -1069,16 +1116,17 @@ CompressedOffloadBundle::compress(llvm::compression::Params P,
10691116 << " Truncated MD5 hash: "
10701117 << llvm::format_hex (TruncatedHash, 16 ) << " \n " ;
10711118 }
1119+
10721120 return llvm::MemoryBuffer::getMemBufferCopy (
10731121 llvm::StringRef (FinalBuffer.data (), FinalBuffer.size ()));
10741122}
10751123
10761124llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
10771125CompressedOffloadBundle::decompress (const llvm::MemoryBuffer &Input,
10781126 bool Verbose) {
1079-
10801127 StringRef Blob = Input.getBuffer ();
10811128
1129+ // Check minimum header size (using V1 as it's the smallest)
10821130 if (Blob.size () < V1HeaderSize)
10831131 return llvm::MemoryBuffer::getMemBufferCopy (Blob);
10841132
@@ -1091,31 +1139,56 @@ CompressedOffloadBundle::decompress(const llvm::MemoryBuffer &Input,
10911139
10921140 size_t CurrentOffset = MagicSize;
10931141
1142+ // Read version
10941143 uint16_t ThisVersion;
10951144 memcpy (&ThisVersion, Blob.data () + CurrentOffset, sizeof (uint16_t ));
10961145 CurrentOffset += VersionFieldSize;
10971146
1147+ // Verify header size based on version
1148+ if (ThisVersion >= 2 && ThisVersion <= 3 ) {
1149+ size_t RequiredSize = (ThisVersion == 2 ) ? V2HeaderSize : V3HeaderSize;
1150+ if (Blob.size () < RequiredSize)
1151+ return createStringError (inconvertibleErrorCode (),
1152+ " Compressed bundle header size too small" );
1153+ }
1154+
1155+ // Read compression method
10981156 uint16_t CompressionMethod;
10991157 memcpy (&CompressionMethod, Blob.data () + CurrentOffset, sizeof (uint16_t ));
11001158 CurrentOffset += MethodFieldSize;
11011159
1102- uint32_t TotalFileSize;
1160+ // Read total file size (version 2+)
1161+ uint64_t TotalFileSize = 0 ;
11031162 if (ThisVersion >= 2 ) {
1104- if (Blob.size () < V2HeaderSize)
1105- return createStringError (inconvertibleErrorCode (),
1106- " Compressed bundle header size too small" );
1107- memcpy (&TotalFileSize, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1108- CurrentOffset += FileSizeFieldSize;
1163+ if (ThisVersion == 2 ) {
1164+ uint32_t TotalFileSize32;
1165+ memcpy (&TotalFileSize32, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1166+ TotalFileSize = TotalFileSize32;
1167+ CurrentOffset += FileSizeFieldSizeV2;
1168+ } else { // Version 3
1169+ memcpy (&TotalFileSize, Blob.data () + CurrentOffset, sizeof (uint64_t ));
1170+ CurrentOffset += FileSizeFieldSizeV3;
1171+ }
11091172 }
11101173
1111- uint32_t UncompressedSize;
1112- memcpy (&UncompressedSize, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1113- CurrentOffset += UncompressedSizeFieldSize;
1174+ // Read uncompressed size
1175+ uint64_t UncompressedSize = 0 ;
1176+ if (ThisVersion <= 2 ) {
1177+ uint32_t UncompressedSize32;
1178+ memcpy (&UncompressedSize32, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1179+ UncompressedSize = UncompressedSize32;
1180+ CurrentOffset += UncompressedSizeFieldSizeV2;
1181+ } else { // Version 3
1182+ memcpy (&UncompressedSize, Blob.data () + CurrentOffset, sizeof (uint64_t ));
1183+ CurrentOffset += UncompressedSizeFieldSizeV3;
1184+ }
11141185
1186+ // Read hash
11151187 uint64_t StoredHash;
11161188 memcpy (&StoredHash, Blob.data () + CurrentOffset, sizeof (uint64_t ));
11171189 CurrentOffset += HashFieldSize;
11181190
1191+ // Determine compression format
11191192 llvm::compression::Format CompressionFormat;
11201193 if (CompressionMethod ==
11211194 static_cast <uint16_t >(llvm::compression::Format::Zlib))
@@ -1381,7 +1454,8 @@ Error OffloadBundler::BundleFiles() {
13811454 auto CompressionResult = CompressedOffloadBundle::compress (
13821455 {BundlerConfig.CompressionFormat , BundlerConfig.CompressionLevel ,
13831456 /* zstdEnableLdm=*/ true },
1384- *BufferMemory, BundlerConfig.Verbose );
1457+ *BufferMemory, BundlerConfig.CompressedBundleVersion ,
1458+ BundlerConfig.Verbose );
13851459 if (auto Error = CompressionResult.takeError ())
13861460 return Error;
13871461
0 commit comments