@@ -926,7 +926,8 @@ CreateFileHandler(MemoryBuffer &FirstInput,
926926 " '" + FilesType + " ': invalid file type specified" );
927927}
928928
929- OffloadBundlerConfig::OffloadBundlerConfig () {
929+ OffloadBundlerConfig::OffloadBundlerConfig ()
930+ : CompressedBundleVersion(CompressedOffloadBundle::DefaultVersion) {
930931 if (llvm::compression::zstd::isAvailable ()) {
931932 CompressionFormat = llvm::compression::Format::Zstd;
932933 // Compression level 3 is usually sufficient for zstd since long distance
@@ -942,16 +943,13 @@ OffloadBundlerConfig::OffloadBundlerConfig() {
942943 llvm::sys::Process::GetEnv (" OFFLOAD_BUNDLER_IGNORE_ENV_VAR" );
943944 if (IgnoreEnvVarOpt.has_value () && IgnoreEnvVarOpt.value () == " 1" )
944945 return ;
945-
946946 auto VerboseEnvVarOpt = llvm::sys::Process::GetEnv (" OFFLOAD_BUNDLER_VERBOSE" );
947947 if (VerboseEnvVarOpt.has_value ())
948948 Verbose = VerboseEnvVarOpt.value () == " 1" ;
949-
950949 auto CompressEnvVarOpt =
951950 llvm::sys::Process::GetEnv (" OFFLOAD_BUNDLER_COMPRESS" );
952951 if (CompressEnvVarOpt.has_value ())
953952 Compress = CompressEnvVarOpt.value () == " 1" ;
954-
955953 auto CompressionLevelEnvVarOpt =
956954 llvm::sys::Process::GetEnv (" OFFLOAD_BUNDLER_COMPRESSION_LEVEL" );
957955 if (CompressionLevelEnvVarOpt.has_value ()) {
@@ -964,6 +962,26 @@ OffloadBundlerConfig::OffloadBundlerConfig() {
964962 << " Warning: Invalid value for OFFLOAD_BUNDLER_COMPRESSION_LEVEL: "
965963 << CompressionLevelStr.str () << " . Ignoring it.\n " ;
966964 }
965+ auto CompressedBundleFormatVersionOpt =
966+ llvm::sys::Process::GetEnv (" COMPRESSED_BUNDLE_FORMAT_VERSION" );
967+ if (CompressedBundleFormatVersionOpt.has_value ()) {
968+ llvm::StringRef VersionStr = CompressedBundleFormatVersionOpt.value ();
969+ uint16_t Version;
970+ if (!VersionStr.getAsInteger (10 , Version)) {
971+ if (Version >= 2 && Version <= 3 )
972+ CompressedBundleVersion = Version;
973+ else
974+ llvm::errs ()
975+ << " Warning: Invalid value for COMPRESSED_BUNDLE_FORMAT_VERSION: "
976+ << VersionStr.str ()
977+ << " . Valid values are 2 or 3. Using default version "
978+ << CompressedBundleVersion << " .\n " ;
979+ } else
980+ llvm::errs ()
981+ << " Warning: Invalid value for COMPRESSED_BUNDLE_FORMAT_VERSION: "
982+ << VersionStr.str () << " . Using default version "
983+ << CompressedBundleVersion << " .\n " ;
984+ }
967985}
968986
969987// Utility function to format numbers with commas
@@ -980,12 +998,11 @@ static std::string formatWithCommas(unsigned long long Value) {
980998llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
981999CompressedOffloadBundle::compress (llvm::compression::Params P,
9821000 const llvm::MemoryBuffer &Input,
983- bool Verbose) {
1001+ uint16_t Version, bool Verbose) {
9841002 if (!llvm::compression::zstd::isAvailable () &&
9851003 !llvm::compression::zlib::isAvailable ())
9861004 return createStringError (llvm::inconvertibleErrorCode (),
9871005 " Compression not supported" );
988-
9891006 llvm::Timer HashTimer (" Hash Calculation Timer" , " Hash calculation time" ,
9901007 ClangOffloadBundlerTimerGroup);
9911008 if (Verbose)
@@ -1002,7 +1019,6 @@ CompressedOffloadBundle::compress(llvm::compression::Params P,
10021019 auto BufferUint8 = llvm::ArrayRef<uint8_t >(
10031020 reinterpret_cast <const uint8_t *>(Input.getBuffer ().data ()),
10041021 Input.getBuffer ().size ());
1005-
10061022 llvm::Timer CompressTimer (" Compression Timer" , " Compression time" ,
10071023 ClangOffloadBundlerTimerGroup);
10081024 if (Verbose)
@@ -1012,22 +1028,54 @@ CompressedOffloadBundle::compress(llvm::compression::Params P,
10121028 CompressTimer.stopTimer ();
10131029
10141030 uint16_t CompressionMethod = static_cast <uint16_t >(P.format );
1015- uint32_t UncompressedSize = Input.getBuffer ().size ();
1016- uint32_t TotalFileSize = MagicNumber.size () + sizeof (TotalFileSize) +
1017- sizeof (Version) + sizeof (CompressionMethod) +
1018- sizeof (UncompressedSize) + sizeof (TruncatedHash) +
1019- CompressedBuffer.size ();
1031+
1032+ // Store sizes in 64-bit variables first
1033+ uint64_t UncompressedSize64 = Input.getBuffer ().size ();
1034+ uint64_t TotalFileSize64;
1035+
1036+ // Calculate total file size based on version
1037+ if (Version == 2 ) {
1038+ // For V2, ensure the sizes don't exceed 32-bit limit
1039+ if (UncompressedSize64 > std::numeric_limits<uint32_t >::max ())
1040+ return createStringError (llvm::inconvertibleErrorCode (),
1041+ " Uncompressed size exceeds version 2 limit" );
1042+ if ((MagicNumber.size () + sizeof (uint32_t ) + sizeof (Version) +
1043+ sizeof (CompressionMethod) + sizeof (uint32_t ) + sizeof (TruncatedHash) +
1044+ CompressedBuffer.size ()) > std::numeric_limits<uint32_t >::max ())
1045+ return createStringError (llvm::inconvertibleErrorCode (),
1046+ " Total file size exceeds version 2 limit" );
1047+
1048+ TotalFileSize64 = MagicNumber.size () + sizeof (uint32_t ) + sizeof (Version) +
1049+ sizeof (CompressionMethod) + sizeof (uint32_t ) +
1050+ sizeof (TruncatedHash) + CompressedBuffer.size ();
1051+ } else { // Version 3
1052+ TotalFileSize64 = MagicNumber.size () + sizeof (uint64_t ) + sizeof (Version) +
1053+ sizeof (CompressionMethod) + sizeof (uint64_t ) +
1054+ sizeof (TruncatedHash) + CompressedBuffer.size ();
1055+ }
10201056
10211057 SmallVector<char , 0 > FinalBuffer;
10221058 llvm::raw_svector_ostream OS (FinalBuffer);
10231059 OS << MagicNumber;
10241060 OS.write (reinterpret_cast <const char *>(&Version), sizeof (Version));
10251061 OS.write (reinterpret_cast <const char *>(&CompressionMethod),
10261062 sizeof (CompressionMethod));
1027- OS.write (reinterpret_cast <const char *>(&TotalFileSize),
1028- sizeof (TotalFileSize));
1029- OS.write (reinterpret_cast <const char *>(&UncompressedSize),
1030- sizeof (UncompressedSize));
1063+
1064+ // Write size fields according to version
1065+ if (Version == 2 ) {
1066+ uint32_t TotalFileSize32 = static_cast <uint32_t >(TotalFileSize64);
1067+ uint32_t UncompressedSize32 = static_cast <uint32_t >(UncompressedSize64);
1068+ OS.write (reinterpret_cast <const char *>(&TotalFileSize32),
1069+ sizeof (TotalFileSize32));
1070+ OS.write (reinterpret_cast <const char *>(&UncompressedSize32),
1071+ sizeof (UncompressedSize32));
1072+ } else { // Version 3
1073+ OS.write (reinterpret_cast <const char *>(&TotalFileSize64),
1074+ sizeof (TotalFileSize64));
1075+ OS.write (reinterpret_cast <const char *>(&UncompressedSize64),
1076+ sizeof (UncompressedSize64));
1077+ }
1078+
10311079 OS.write (reinterpret_cast <const char *>(&TruncatedHash),
10321080 sizeof (TruncatedHash));
10331081 OS.write (reinterpret_cast <const char *>(CompressedBuffer.data ()),
@@ -1037,18 +1085,17 @@ CompressedOffloadBundle::compress(llvm::compression::Params P,
10371085 auto MethodUsed =
10381086 P.format == llvm::compression::Format::Zstd ? " zstd" : " zlib" ;
10391087 double CompressionRate =
1040- static_cast <double >(UncompressedSize ) / CompressedBuffer.size ();
1088+ static_cast <double >(UncompressedSize64 ) / CompressedBuffer.size ();
10411089 double CompressionTimeSeconds = CompressTimer.getTotalTime ().getWallTime ();
10421090 double CompressionSpeedMBs =
1043- (UncompressedSize / (1024.0 * 1024.0 )) / CompressionTimeSeconds;
1044-
1091+ (UncompressedSize64 / (1024.0 * 1024.0 )) / CompressionTimeSeconds;
10451092 llvm::errs () << " Compressed bundle format version: " << Version << " \n "
10461093 << " Total file size (including headers): "
1047- << formatWithCommas (TotalFileSize ) << " bytes\n "
1094+ << formatWithCommas (TotalFileSize64 ) << " bytes\n "
10481095 << " Compression method used: " << MethodUsed << " \n "
10491096 << " Compression level: " << P.level << " \n "
10501097 << " Binary size before compression: "
1051- << formatWithCommas (UncompressedSize ) << " bytes\n "
1098+ << formatWithCommas (UncompressedSize64 ) << " bytes\n "
10521099 << " Binary size after compression: "
10531100 << formatWithCommas (CompressedBuffer.size ()) << " bytes\n "
10541101 << " Compression rate: "
@@ -1060,16 +1107,17 @@ CompressedOffloadBundle::compress(llvm::compression::Params P,
10601107 << " Truncated MD5 hash: "
10611108 << llvm::format_hex (TruncatedHash, 16 ) << " \n " ;
10621109 }
1110+
10631111 return llvm::MemoryBuffer::getMemBufferCopy (
10641112 llvm::StringRef (FinalBuffer.data (), FinalBuffer.size ()));
10651113}
10661114
10671115llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
10681116CompressedOffloadBundle::decompress (const llvm::MemoryBuffer &Input,
10691117 bool Verbose) {
1070-
10711118 StringRef Blob = Input.getBuffer ();
10721119
1120+ // Check minimum header size (using V1 as it's the smallest)
10731121 if (Blob.size () < V1HeaderSize)
10741122 return llvm::MemoryBuffer::getMemBufferCopy (Blob);
10751123
@@ -1082,31 +1130,56 @@ CompressedOffloadBundle::decompress(const llvm::MemoryBuffer &Input,
10821130
10831131 size_t CurrentOffset = MagicSize;
10841132
1133+ // Read version
10851134 uint16_t ThisVersion;
10861135 memcpy (&ThisVersion, Blob.data () + CurrentOffset, sizeof (uint16_t ));
10871136 CurrentOffset += VersionFieldSize;
10881137
1138+ // Verify header size based on version
1139+ if (ThisVersion >= 2 && ThisVersion <= 3 ) {
1140+ size_t RequiredSize = (ThisVersion == 2 ) ? V2HeaderSize : V3HeaderSize;
1141+ if (Blob.size () < RequiredSize)
1142+ return createStringError (inconvertibleErrorCode (),
1143+ " Compressed bundle header size too small" );
1144+ }
1145+
1146+ // Read compression method
10891147 uint16_t CompressionMethod;
10901148 memcpy (&CompressionMethod, Blob.data () + CurrentOffset, sizeof (uint16_t ));
10911149 CurrentOffset += MethodFieldSize;
10921150
1093- uint32_t TotalFileSize;
1151+ // Read total file size (version 2+)
1152+ uint64_t TotalFileSize = 0 ;
10941153 if (ThisVersion >= 2 ) {
1095- if (Blob.size () < V2HeaderSize)
1096- return createStringError (inconvertibleErrorCode (),
1097- " Compressed bundle header size too small" );
1098- memcpy (&TotalFileSize, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1099- CurrentOffset += FileSizeFieldSize;
1154+ if (ThisVersion == 2 ) {
1155+ uint32_t TotalFileSize32;
1156+ memcpy (&TotalFileSize32, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1157+ TotalFileSize = TotalFileSize32;
1158+ CurrentOffset += FileSizeFieldSizeV2;
1159+ } else { // Version 3
1160+ memcpy (&TotalFileSize, Blob.data () + CurrentOffset, sizeof (uint64_t ));
1161+ CurrentOffset += FileSizeFieldSizeV3;
1162+ }
11001163 }
11011164
1102- uint32_t UncompressedSize;
1103- memcpy (&UncompressedSize, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1104- CurrentOffset += UncompressedSizeFieldSize;
1165+ // Read uncompressed size
1166+ uint64_t UncompressedSize = 0 ;
1167+ if (ThisVersion <= 2 ) {
1168+ uint32_t UncompressedSize32;
1169+ memcpy (&UncompressedSize32, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1170+ UncompressedSize = UncompressedSize32;
1171+ CurrentOffset += UncompressedSizeFieldSizeV2;
1172+ } else { // Version 3
1173+ memcpy (&UncompressedSize, Blob.data () + CurrentOffset, sizeof (uint64_t ));
1174+ CurrentOffset += UncompressedSizeFieldSizeV3;
1175+ }
11051176
1177+ // Read hash
11061178 uint64_t StoredHash;
11071179 memcpy (&StoredHash, Blob.data () + CurrentOffset, sizeof (uint64_t ));
11081180 CurrentOffset += HashFieldSize;
11091181
1182+ // Determine compression format
11101183 llvm::compression::Format CompressionFormat;
11111184 if (CompressionMethod ==
11121185 static_cast <uint16_t >(llvm::compression::Format::Zlib))
@@ -1372,7 +1445,8 @@ Error OffloadBundler::BundleFiles() {
13721445 auto CompressionResult = CompressedOffloadBundle::compress (
13731446 {BundlerConfig.CompressionFormat , BundlerConfig.CompressionLevel ,
13741447 /* zstdEnableLdm=*/ true },
1375- *BufferMemory, BundlerConfig.Verbose );
1448+ *BufferMemory, BundlerConfig.CompressedBundleVersion ,
1449+ BundlerConfig.Verbose );
13761450 if (auto Error = CompressionResult.takeError ())
13771451 return Error;
13781452
0 commit comments